Move abs, frac, reciprocal, and neg to TensorIterator (#19041)
authorJames Reed <jamesreed@fb.com>
Wed, 10 Apr 2019 04:48:49 +0000 (21:48 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 10 Apr 2019 04:55:00 +0000 (21:55 -0700)
Summary:
I've been messing around with vectorizing the fusion compiler in JIT, and noticed that these ops were pathologically slow. I moved them to use TensorIterator + Vec256<> and got some speed wins.

Benchmark script:

```
import torch, time

ops = ['abs', 'neg', 'reciprocal', 'frac']

x = torch.rand(1024, 1024)
NITER = 10000

print('op', 'time per iter (ms)', 'gops/s', 'GB/s', sep='\t')

for op in ops:
    s = time.time()
    for i in range(NITER):
        getattr(x, op)()
    elapsed_sec = ((time.time() - s) / NITER)
    print(op, elapsed_sec * 1000, (1024*1024/elapsed_sec)/1e9, (1024*1024*4*2) / elapsed_sec / 1e9, sep='\t')

```

Before this change (on my mac with a skylake):
```
op      time per iter (ms)      gops/s  GB/s
abs     0.9730974197387695      1.0775652866097343      8.620522292877874
neg     1.0723679780960083      0.9778136063534356      7.822508850827485
reciprocal      1.2610594034194946      0.8315040490215421      6.6520323921723366
frac    1.1681334018707275      0.8976509004200546      7.181207203360437
```

After this change:
```
op      time per iter (ms)      gops/s  GB/s
abs     0.5031076192855835      2.084198210889721       16.673585687117768
neg     0.4433974027633667      2.3648672578256087      18.91893806260487
reciprocal      0.47145988941192624     2.2241043693195985      17.79283495455679
frac    0.5036592721939087      2.0819154096627024      16.65532327730162
```

So, after this change it looks like we are hitting machine peak for bandwidth and are bandwidth bound.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19041

Differential Revision: D14862037

Pulled By: jamesr66a

fbshipit-source-id: e2032ac0ca962dbf4120bb36812277c260e22912

15 files changed:
aten/src/ATen/Declarations.cwrap
aten/src/ATen/core/Tensor.h
aten/src/ATen/core/TensorMethods.h
aten/src/ATen/core/Type.h
aten/src/ATen/cpu/vec256/vec256_base.h
aten/src/ATen/cpu/vec256/vec256_double.h
aten/src/ATen/cpu/vec256/vec256_float.h
aten/src/ATen/cpu/vec256/vec256_int.h
aten/src/ATen/native/LegacyDefinitions.cpp
aten/src/ATen/native/UnaryOps.cpp
aten/src/ATen/native/UnaryOps.h
aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
aten/src/ATen/native/cuda/CUDAUnaryOps.cpp
aten/src/ATen/native/native_functions.yaml
test/test_torch.py

index a10f761..43fb088 100644 (file)
   name: _th_abs
   cname: abs
   backends:
-    - CPU
     - CUDA
   variants: function
   return: argument 0
   types:
     - floating_point
   backends:
-    - CPU
     - CUDA
   cname: frac
   variants: function
   types:
     - floating_point
   backends:
-    - CPU
     - CUDA
   variants:
     - function
   types:
     - floating_point
   backends:
-    - CPU
     - CUDA
   variants:
     - function
   types:
     - floating_point
   backends:
-    - CPU
     - CUDA
   variants: function
   options:
 [[
   name: _th_neg
   backends:
-    - CPU
     - CUDA
   variants:
     - function
 [[
   name: _th_neg_
   backends:
-    - CPU
     - CUDA
   variants: function
   options:
index 009ad2e..6f58a4a 100644 (file)
@@ -408,6 +408,8 @@ class CAFFE2_API Tensor {
   Tensor & fill_(const Tensor & value);
   Tensor floor() const;
   Tensor & floor_();
+  Tensor frac() const;
+  Tensor & frac_();
   Tensor ger(const Tensor & vec2) const;
   Tensor fft(int64_t signal_ndim, bool normalized=false) const;
   Tensor ifft(int64_t signal_ndim, bool normalized=false) const;
@@ -465,6 +467,10 @@ class CAFFE2_API Tensor {
   Tensor permute(IntArrayRef dims) const;
   Tensor pin_memory() const;
   Tensor pinverse(double rcond=1e-15) const;
+  Tensor reciprocal() const;
+  Tensor & reciprocal_();
+  Tensor neg() const;
+  Tensor & neg_();
   Tensor repeat(IntArrayRef repeats) const;
   Tensor repeat_interleave(const Tensor & repeats, c10::optional<int64_t> dim=c10::nullopt) const;
   Tensor repeat_interleave(int64_t repeats, c10::optional<int64_t> dim=c10::nullopt) const;
@@ -648,10 +654,7 @@ class CAFFE2_API Tensor {
   Tensor & digamma_();
   Tensor & polygamma_(int64_t n);
   Tensor & erfinv_();
-  Tensor & frac_();
   Tensor & renorm_(Scalar p, int64_t dim, Scalar maxnorm);
-  Tensor & reciprocal_();
-  Tensor & neg_();
   Tensor & pow_(Scalar exponent);
   Tensor & pow_(const Tensor & exponent);
   Tensor & lerp_(const Tensor & end, Scalar weight);
@@ -718,10 +721,7 @@ class CAFFE2_API Tensor {
   Tensor digamma() const;
   Tensor polygamma(int64_t n) const;
   Tensor erfinv() const;
-  Tensor frac() const;
   Tensor dist(const Tensor & other, Scalar p=2) const;
-  Tensor reciprocal() const;
-  Tensor neg() const;
   Tensor atan2(const Tensor & other) const;
   Tensor lerp(const Tensor & end, Scalar weight) const;
   Tensor lerp(const Tensor & end, const Tensor & weight) const;
index 065af04..26f3807 100644 (file)
@@ -280,6 +280,12 @@ inline Tensor Tensor::floor() const {
 inline Tensor & Tensor::floor_() {
     return dispatch_type().floor_(*this);
 }
+inline Tensor Tensor::frac() const {
+    return dispatch_type().frac(*this);
+}
+inline Tensor & Tensor::frac_() {
+    return dispatch_type().frac_(*this);
+}
 inline Tensor Tensor::ger(const Tensor & vec2) const {
     return dispatch_type().ger(*this, vec2);
 }
@@ -451,6 +457,18 @@ inline Tensor Tensor::pin_memory() const {
 inline Tensor Tensor::pinverse(double rcond) const {
     return dispatch_type().pinverse(*this, rcond);
 }
+inline Tensor Tensor::reciprocal() const {
+    return dispatch_type().reciprocal(*this);
+}
+inline Tensor & Tensor::reciprocal_() {
+    return dispatch_type().reciprocal_(*this);
+}
+inline Tensor Tensor::neg() const {
+    return dispatch_type().neg(*this);
+}
+inline Tensor & Tensor::neg_() {
+    return dispatch_type().neg_(*this);
+}
 inline Tensor Tensor::repeat(IntArrayRef repeats) const {
     return dispatch_type().repeat(*this, repeats);
 }
@@ -1000,18 +1018,9 @@ inline Tensor & Tensor::polygamma_(int64_t n) {
 inline Tensor & Tensor::erfinv_() {
     return dispatch_type().erfinv_(*this);
 }
-inline Tensor & Tensor::frac_() {
-    return dispatch_type().frac_(*this);
-}
 inline Tensor & Tensor::renorm_(Scalar p, int64_t dim, Scalar maxnorm) {
     return dispatch_type().renorm_(*this, p, dim, maxnorm);
 }
-inline Tensor & Tensor::reciprocal_() {
-    return dispatch_type().reciprocal_(*this);
-}
-inline Tensor & Tensor::neg_() {
-    return dispatch_type().neg_(*this);
-}
 inline Tensor & Tensor::pow_(Scalar exponent) {
     return dispatch_type().pow_(*this, exponent);
 }
@@ -1210,18 +1219,9 @@ inline Tensor Tensor::polygamma(int64_t n) const {
 inline Tensor Tensor::erfinv() const {
     return dispatch_type().erfinv(*this);
 }
-inline Tensor Tensor::frac() const {
-    return dispatch_type().frac(*this);
-}
 inline Tensor Tensor::dist(const Tensor & other, Scalar p) const {
     return dispatch_type().dist(*this, other, p);
 }
-inline Tensor Tensor::reciprocal() const {
-    return dispatch_type().reciprocal(*this);
-}
-inline Tensor Tensor::neg() const {
-    return dispatch_type().neg(*this);
-}
 inline Tensor Tensor::atan2(const Tensor & other) const {
     return dispatch_type().atan2(*this, other);
 }
index 97871be..15d5efd 100644 (file)
@@ -282,6 +282,8 @@ struct CAFFE2_API Type {
   virtual Tensor & fill_(Tensor & self, const Tensor & value) const = 0;
   virtual Tensor floor(const Tensor & self) const = 0;
   virtual Tensor & floor_(Tensor & self) const = 0;
+  virtual Tensor frac(const Tensor & self) const = 0;
+  virtual Tensor & frac_(Tensor & self) const = 0;
   virtual Tensor ger(const Tensor & self, const Tensor & vec2) const = 0;
   virtual Tensor fft(const Tensor & self, int64_t signal_ndim, bool normalized) const = 0;
   virtual Tensor ifft(const Tensor & self, int64_t signal_ndim, bool normalized) const = 0;
@@ -339,6 +341,10 @@ struct CAFFE2_API Type {
   virtual Tensor permute(const Tensor & self, IntArrayRef dims) const = 0;
   virtual Tensor pin_memory(const Tensor & self) const = 0;
   virtual Tensor pinverse(const Tensor & self, double rcond) const = 0;
+  virtual Tensor reciprocal(const Tensor & self) const = 0;
+  virtual Tensor & reciprocal_(Tensor & self) const = 0;
+  virtual Tensor neg(const Tensor & self) const = 0;
+  virtual Tensor & neg_(Tensor & self) const = 0;
   virtual Tensor repeat(const Tensor & self, IntArrayRef repeats) const = 0;
   virtual Tensor repeat_interleave(const Tensor & repeats) const = 0;
   virtual Tensor repeat_interleave(const Tensor & self, const Tensor & repeats, c10::optional<int64_t> dim) const = 0;
@@ -523,10 +529,7 @@ struct CAFFE2_API Type {
   virtual Tensor & digamma_(Tensor & self) const = 0;
   virtual Tensor & polygamma_(Tensor & self, int64_t n) const = 0;
   virtual Tensor & erfinv_(Tensor & self) const = 0;
-  virtual Tensor & frac_(Tensor & self) const = 0;
   virtual Tensor & renorm_(Tensor & self, Scalar p, int64_t dim, Scalar maxnorm) const = 0;
-  virtual Tensor & reciprocal_(Tensor & self) const = 0;
-  virtual Tensor & neg_(Tensor & self) const = 0;
   virtual Tensor & pow_(Tensor & self, Scalar exponent) const = 0;
   virtual Tensor & pow_(Tensor & self, const Tensor & exponent) const = 0;
   virtual Tensor & lerp_(Tensor & self, const Tensor & end, Scalar weight) const = 0;
@@ -593,10 +596,7 @@ struct CAFFE2_API Type {
   virtual Tensor digamma(const Tensor & self) const = 0;
   virtual Tensor polygamma(int64_t n, const Tensor & self) const = 0;
   virtual Tensor erfinv(const Tensor & self) const = 0;
-  virtual Tensor frac(const Tensor & self) const = 0;
   virtual Tensor dist(const Tensor & self, const Tensor & other, Scalar p) const = 0;
-  virtual Tensor reciprocal(const Tensor & self) const = 0;
-  virtual Tensor neg(const Tensor & self) const = 0;
   virtual Tensor atan2(const Tensor & self, const Tensor & other) const = 0;
   virtual Tensor lerp(const Tensor & self, const Tensor & end, Scalar weight) const = 0;
   virtual Tensor lerp(const Tensor & self, const Tensor & end, const Tensor & weight) const = 0;
index b7a0df2..36b1db4 100644 (file)
@@ -194,6 +194,9 @@ public:
   Vec256<T> expm1() const {
     return map(std::expm1);
   }
+  Vec256<T> frac() const {
+    return *this - this->trunc();
+  }
   Vec256<T> log() const {
     return map(std::log);
   }
@@ -219,7 +222,10 @@ public:
     return map(std::floor);
   }
   Vec256<T> neg() const {
-    return map([](T x) { return -x; });
+    // NB: the trailing return type is needed because we need to coerce the
+    // return value back to T in the case of unary operator- incuring a
+    // promotion
+    return map([](T x) -> T { return -x; });
   }
   Vec256<T> round() const {
     return map(std::nearbyint);
index c5fea7d..0c963b7 100644 (file)
@@ -141,6 +141,7 @@ public:
   Vec256<double> floor() const {
     return _mm256_floor_pd(values);
   }
+  Vec256<double> frac() const;
   Vec256<double> neg() const {
     return _mm256_xor_pd(_mm256_set1_pd(-0.), values);
   }
@@ -216,6 +217,11 @@ Vec256<double> inline operator/(const Vec256<double>& a, const Vec256<double>& b
   return _mm256_div_pd(a, b);
 }
 
+// frac. Implement this here so we can use subtraction.
+Vec256<double> Vec256<double>::frac() const {
+  return *this - this->trunc();
+}
+
 // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
 // either input is a NaN.
 template <>
index dfaa126..0937e50 100644 (file)
@@ -131,6 +131,7 @@ public:
   Vec256<float> log1p() const {
     return Vec256<float>(Sleef_log1pf8_u10(values));
   }
+  Vec256<float> frac() const;
   Vec256<float> sin() const {
     return map(std::sin);
   }
@@ -224,6 +225,11 @@ Vec256<float> inline operator/(const Vec256<float>& a, const Vec256<float>& b) {
   return _mm256_div_ps(a, b);
 }
 
+// frac. Implement this here so we can use subtraction
+Vec256<float> Vec256<float>::frac() const {
+  return *this - this->trunc();
+}
+
 // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
 // either input is a NaN.
 template <>
index bed5dbb..8a22d90 100644 (file)
@@ -96,6 +96,8 @@ struct Vec256<int64_t> : public Vec256i {
     auto inverse = _mm256_xor_si256(values, is_larger);
     return _mm256_sub_epi64(inverse, is_larger);
   }
+  Vec256<int64_t> frac() const;
+  Vec256<int64_t> neg() const;
   Vec256<int64_t> operator==(const Vec256<int64_t>& other) const {
     return _mm256_cmpeq_epi64(values, other.values);
   }
@@ -185,6 +187,8 @@ struct Vec256<int32_t> : public Vec256i {
   Vec256<int32_t> abs() const {
     return _mm256_abs_epi32(values);
   }
+  Vec256<int32_t> frac() const;
+  Vec256<int32_t> neg() const;
   Vec256<int32_t> operator==(const Vec256<int32_t>& other) const {
     return _mm256_cmpeq_epi32(values, other.values);
   }
@@ -369,6 +373,8 @@ struct Vec256<int16_t> : public Vec256i {
   Vec256<int16_t> abs() const {
     return _mm256_abs_epi16(values);
   }
+  Vec256<int16_t> frac() const;
+  Vec256<int16_t> neg() const;
   Vec256<int16_t> operator==(const Vec256<int16_t>& other) const {
     return _mm256_cmpeq_epi16(values, other.values);
   }
@@ -419,6 +425,19 @@ Vec256<int16_t> inline operator-(const Vec256<int16_t>& a, const Vec256<int16_t>
   return _mm256_sub_epi16(a, b);
 }
 
+// Negation. Defined here so we can utilize operator-
+Vec256<int64_t> Vec256<int64_t>::neg() const {
+  return Vec256<int64_t>(0) - *this;
+}
+
+Vec256<int32_t> Vec256<int32_t>::neg() const {
+  return Vec256<int32_t>(0) - *this;
+}
+
+Vec256<int16_t> Vec256<int16_t>::neg() const {
+  return Vec256<int16_t>(0) - *this;
+}
+
 // Emulate operations with no native 64-bit support in avx,
 // by extracting each element, performing the operation pointwise,
 // then combining the results into a vector.
index 60dce0b..a266599 100644 (file)
@@ -162,22 +162,10 @@ Tensor & erfinv_(Tensor& self) {
   return at::legacy::th::_th_erfinv_(self);
 }
 
-Tensor & frac_(Tensor& self) {
-  return at::legacy::th::_th_frac_(self);
-}
-
 Tensor & renorm_(Tensor& self, Scalar p, int64_t dim, Scalar maxnorm) {
   return at::legacy::th::_th_renorm_(self, p, dim, maxnorm);
 }
 
-Tensor & reciprocal_(Tensor& self) {
-  return at::legacy::th::_th_reciprocal_(self);
-}
-
-Tensor & neg_(Tensor& self) {
-  return at::legacy::th::_th_neg_(self);
-}
-
 Tensor & pow_(Tensor& self, Scalar exponent) {
   return at::legacy::th::_th_pow_(self, exponent);
 }
@@ -563,34 +551,10 @@ Tensor erfinv(const Tensor & self) {
   return at::legacy::th::_th_erfinv(self);
 }
 
-Tensor & frac_out(Tensor & result, const Tensor & self) {
-  return at::legacy::th::_th_frac_out(result, self);
-}
-
-Tensor frac(const Tensor & self) {
-  return at::legacy::th::_th_frac(self);
-}
-
 Tensor dist(const Tensor & self, const Tensor & other, Scalar p) {
   return at::legacy::th::_th_dist(self, other, p);
 }
 
-Tensor & reciprocal_out(Tensor & result, const Tensor & self) {
-  return at::legacy::th::_th_reciprocal_out(result, self);
-}
-
-Tensor reciprocal(const Tensor & self) {
-  return at::legacy::th::_th_reciprocal(self);
-}
-
-Tensor & neg_out(Tensor & result, const Tensor & self) {
-  return at::legacy::th::_th_neg_out(result, self);
-}
-
-Tensor neg(const Tensor & self) {
-  return at::legacy::th::_th_neg(self);
-}
-
 Tensor & atan2_out(Tensor & result, const Tensor & self, const Tensor & other) {
   return at::legacy::th::_th_atan2_out(result, self, other);
 }
index d690ba6..9121a36 100644 (file)
@@ -115,7 +115,6 @@ Tensor& mvlgamma_(Tensor& self, int64_t p) {
   return self.copy_(args.lgamma_().sum(-1).add_(p * (p - 1) * std::log(M_PI) / 4.));
 }
 
-
 Tensor sigmoid(const Tensor& self) {
   Tensor result = at::empty({0}, self.options());
   return at::sigmoid_out(result, self);
@@ -167,7 +166,7 @@ Tensor& _sigmoid_out_cpu(Tensor& result, const Tensor& self) {
 
 // NB: Temp. defaulting to TH implementation of abs due to issues with Apple
 
-IMPLEMENT_UNARY_OP_TH(abs)
+IMPLEMENT_UNARY_OP_VEC(abs)
 IMPLEMENT_UNARY_OP_VEC(acos)
 IMPLEMENT_UNARY_OP_VEC(asin)
 IMPLEMENT_UNARY_OP_VEC(atan)
@@ -179,10 +178,13 @@ IMPLEMENT_UNARY_OP_VEC(erfc)
 IMPLEMENT_UNARY_OP_VEC(exp)
 IMPLEMENT_UNARY_OP_VEC(expm1)
 IMPLEMENT_UNARY_OP_VEC(floor)
+IMPLEMENT_UNARY_OP_VEC(frac)
 IMPLEMENT_UNARY_OP_VEC(log)
 IMPLEMENT_UNARY_OP_VEC(log10)
 IMPLEMENT_UNARY_OP_VEC(log1p)
 IMPLEMENT_UNARY_OP_VEC(log2)
+IMPLEMENT_UNARY_OP_VEC(neg)
+IMPLEMENT_UNARY_OP_VEC(reciprocal)
 IMPLEMENT_UNARY_OP_VEC(round)
 IMPLEMENT_UNARY_OP_VEC(rsqrt)
 IMPLEMENT_UNARY_OP_VEC(sin)
@@ -203,10 +205,13 @@ DEFINE_DISPATCH(erfc_stub);
 DEFINE_DISPATCH(exp_stub);
 DEFINE_DISPATCH(expm1_stub);
 DEFINE_DISPATCH(floor_stub);
+DEFINE_DISPATCH(frac_stub);
 DEFINE_DISPATCH(log_stub);
 DEFINE_DISPATCH(log10_stub);
 DEFINE_DISPATCH(log1p_stub);
 DEFINE_DISPATCH(log2_stub);
+DEFINE_DISPATCH(neg_stub);
+DEFINE_DISPATCH(reciprocal_stub);
 DEFINE_DISPATCH(round_stub);
 DEFINE_DISPATCH(rsqrt_stub);
 DEFINE_DISPATCH(sigmoid_stub);
index e60bda8..b6758ca 100644 (file)
@@ -23,10 +23,13 @@ DECLARE_DISPATCH(unary_fn, erfc_stub);
 DECLARE_DISPATCH(unary_fn, exp_stub);
 DECLARE_DISPATCH(unary_fn, expm1_stub);
 DECLARE_DISPATCH(unary_fn, floor_stub);
+DECLARE_DISPATCH(unary_fn, frac_stub);
 DECLARE_DISPATCH(unary_fn, log_stub);
 DECLARE_DISPATCH(unary_fn, log10_stub);
 DECLARE_DISPATCH(unary_fn, log1p_stub);
 DECLARE_DISPATCH(unary_fn, log2_stub);
+DECLARE_DISPATCH(unary_fn, neg_stub);
+DECLARE_DISPATCH(unary_fn, reciprocal_stub);
 DECLARE_DISPATCH(unary_fn, round_stub);
 DECLARE_DISPATCH(unary_fn, rsqrt_stub);
 DECLARE_DISPATCH(unary_fn, sigmoid_stub);
@@ -44,12 +47,9 @@ DECLARE_DISPATCH(void(*)(Tensor&, const double, Generator *), bernoulli_mkl_stub
 // lgamma
 // erfinv
 // fill
-// frac
 // clone
 // contiguous
 // clamp/_min/_max
-// neg
-// reciprocal
 // sign
 // zero
 }} // namespace at::native
index c6d309c..9ede538 100644 (file)
@@ -45,6 +45,42 @@ static void sigmoid_kernel(TensorIterator& iter) {
   });
 }
 
+static void abs_kernel(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES(iter.dtype(), "abs_cpu", [&]() {
+    unary_kernel_vec(
+        iter,
+        [=](scalar_t a) -> scalar_t { return std::abs(a); },
+        [=](Vec256<scalar_t> a) { return a.abs(); });
+  });
+}
+
+static void frac_kernel(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "frac_cpu", [&]() {
+    unary_kernel_vec(
+        iter,
+        [=](scalar_t a) -> scalar_t { return a - std::trunc(a); },
+        [=](Vec256<scalar_t> a) { return a.frac(); });
+  });
+}
+
+static void reciprocal_kernel(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "reciprocal_cpu", [&]() {
+    unary_kernel_vec(
+        iter,
+        [=](scalar_t a) -> scalar_t { return decltype(a)(1.0) / a; },
+        [=](Vec256<scalar_t> a) { return a.reciprocal(); });
+  });
+}
+
+static void neg_kernel(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES(iter.dtype(), "neg_cpu", [&]() {
+    unary_kernel_vec(
+        iter,
+        [=](scalar_t a) -> scalar_t { return -a; },
+        [=](Vec256<scalar_t> a) { return a.neg(); });
+  });
+}
+
 #if !AT_MKL_ENABLED()
 void bernoulli_mkl_kernel(Tensor &output, const double p, Generator* gen) {
   // Use AT_ASSERTM because this should never be reached, and AT_ASSERTM tells
@@ -152,6 +188,10 @@ static void rsqrt_kernel(TensorIterator& iter) {
 REGISTER_DISPATCH(rsqrt_stub, &rsqrt_kernel)
 REGISTER_DISPATCH(sigmoid_stub, &sigmoid_kernel)
 REGISTER_DISPATCH(bernoulli_mkl_stub, &bernoulli_mkl_kernel);
+REGISTER_DISPATCH(abs_stub, &abs_kernel);
+REGISTER_DISPATCH(frac_stub, &frac_kernel);
+REGISTER_DISPATCH(reciprocal_stub, &reciprocal_kernel);
+REGISTER_DISPATCH(neg_stub, &neg_kernel);
 
 // IMPLEMENT_FLOAT_KERNEL(ALL, abs)
 IMPLEMENT_FLOAT_KERNEL(FLOATING, acos)
index 0f80998..73dccdf 100644 (file)
@@ -61,11 +61,14 @@ IMPLEMENT_UNARY_OP_PREQUEL(erf)
 IMPLEMENT_UNARY_OP_PREQUEL(erfc)
 IMPLEMENT_UNARY_OP_PREQUEL(exp)
 IMPLEMENT_UNARY_OP_PREQUEL(expm1)
+IMPLEMENT_UNARY_OP_PREQUEL(frac)
 IMPLEMENT_UNARY_OP_PREQUEL(floor)
 IMPLEMENT_UNARY_OP_PREQUEL(log)
 IMPLEMENT_UNARY_OP_PREQUEL(log10)
 IMPLEMENT_UNARY_OP_PREQUEL(log1p)
 IMPLEMENT_UNARY_OP_PREQUEL(log2)
+IMPLEMENT_UNARY_OP_PREQUEL(neg)
+IMPLEMENT_UNARY_OP_PREQUEL(reciprocal)
 IMPLEMENT_UNARY_OP_PREQUEL(round)
 IMPLEMENT_UNARY_OP_PREQUEL(rsqrt)
 IMPLEMENT_UNARY_OP_PREQUEL(sigmoid)
index 43ece56..1499071 100644 (file)
     CPU: _floor_out_cpu
     CUDA: _floor_out_cuda
 
+- func: frac(Tensor self) -> Tensor
+  matches_jit_signature: True
+  variants: function, method
+
+- func: frac_(Tensor(a!) self) -> Tensor(a!)
+  matches_jit_signature: True
+  variants: function, method
+  dispatch:
+    CPU: _frac__cpu
+    CUDA: _frac__cuda
+
+- func: frac(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  matches_jit_signature: True
+  dispatch:
+    CPU: _frac_out_cpu
+    CUDA: _frac_out_cuda
+
 - func: full(int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
   matches_jit_signature: True
 
     CPU: range_cpu_out
     CUDA: range_cuda_out
 
+- func: reciprocal(Tensor self) -> Tensor
+  matches_jit_signature: True
+  variants: function, method
+
+- func: reciprocal_(Tensor(a!) self) -> Tensor(a!)
+  matches_jit_signature: True
+  variants: function, method
+  dispatch:
+    CPU: _reciprocal__cpu
+    CUDA: _reciprocal__cuda
+
+- func: reciprocal(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  matches_jit_signature: True
+  dispatch:
+    CPU: _reciprocal_out_cpu
+    CUDA: _reciprocal_out_cuda
+
+- func: neg(Tensor self) -> Tensor
+  matches_jit_signature: True
+  variants: function, method
+
+- func: neg_(Tensor(a!) self) -> Tensor(a!)
+  matches_jit_signature: True
+  variants: function, method
+  dispatch:
+    CPU: _neg__cpu
+    CUDA: _neg__cuda
+
+- func: neg(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  matches_jit_signature: True
+  dispatch:
+    CPU: _neg_out_cpu
+    CUDA: _neg_out_cuda
+
 - func: repeat(Tensor self, int[] repeats) -> Tensor
   matches_jit_signature: True
   variants: method  # This is method-only to match the previous tensor API. In the future we could make this a function too.
   matches_jit_signature: True
   variants: method
 
-- func: frac_(Tensor(a!) self) -> Tensor(a!)
-  matches_jit_signature: True
-  variants: method
-
 - func: renorm_(Tensor(a!) self, Scalar p, int dim, Scalar maxnorm) -> Tensor(a!)
   matches_jit_signature: True
   variants: method
 
-- func: reciprocal_(Tensor(a!) self) -> Tensor(a!)
-  matches_jit_signature: True
-  variants: method
-
-- func: neg_(Tensor(a!) self) -> Tensor(a!)
-  matches_jit_signature: True
-  variants: method
-
 - func: pow_(Tensor(a!) self, Scalar exponent) -> Tensor(a!)
   matches_jit_signature: True
   variants: method
   matches_jit_signature: True
   variants: method, function
 
-- func: frac(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
-  matches_jit_signature: True
-
-- func: frac(Tensor self) -> Tensor
-  matches_jit_signature: True
-  variants: method, function
-
 - func: dist(Tensor self, Tensor other, Scalar p=2) -> Tensor
   matches_jit_signature: True
   variants: method, function
 
-- func: reciprocal(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
-  matches_jit_signature: True
-
-- func: reciprocal(Tensor self) -> Tensor
-  matches_jit_signature: True
-  variants: method, function
-
-- func: neg(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
-  matches_jit_signature: True
-
-- func: neg(Tensor self) -> Tensor
-  matches_jit_signature: True
-  variants: method, function
-
 - func: atan2(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
 
index 6903fdd..2a4c576 100644 (file)
@@ -1549,11 +1549,12 @@ class _TestTorchMixin(object):
                 self.assertTrue(y.le(0).any())
 
     def test_reciprocal(self):
-        a = torch.randn(100, 89)
-        res_div = 1 / a
-        res_reciprocal = a.clone()
-        res_reciprocal.reciprocal_()
-        self.assertEqual(res_reciprocal, res_div)
+        for dtype in [torch.float, torch.double]:
+            a = torch.randn(100, 89, dtype=dtype)
+            res_div = 1 / a
+            res_reciprocal = a.clone()
+            res_reciprocal.reciprocal_()
+            self.assertEqual(res_reciprocal, res_div)
 
     def test_mul(self):
         m1 = torch.randn(10, 10)