Use TensorIterator for unary operations
author    Christian Puhrsch <cpuhrsch@fb.com>
          Tue, 26 Mar 2019 16:19:51 +0000 (09:19 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
          Tue, 26 Mar 2019 16:22:52 +0000 (09:22 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18309

Differential Revision: D14591533

Pulled By: cpuhrsch

fbshipit-source-id: a3b0788a481bddf1803c9f2d3289263d7364f8d7

aten/src/ATen/CPUApplyUtils.h
aten/src/ATen/native/Copy.cpp
aten/src/ATen/native/Distributions.cpp
aten/src/ATen/native/TensorIterator.cpp
aten/src/ATen/native/TensorIterator.h
aten/src/ATen/native/UnaryOps.cpp
aten/src/ATen/native/UnaryOps.h [new file with mode: 0644]
aten/src/ATen/native/cpu/Loops.h
aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
aten/src/ATen/native/cpu/UnaryOpsKernel.h [deleted file]

diff --git a/aten/src/ATen/CPUApplyUtils.h b/aten/src/ATen/CPUApplyUtils.h
index 7aadeeb..20753dc 100644
@@ -498,37 +498,6 @@ inline void CPU_tensor_apply4(
   }
 }
 
-template <typename scalar1, typename Op>
-inline void CPU_tensor_parallel_apply1(
-    Tensor tensor1,
-    const Op op,
-    int64_t grain_size = internal::GRAIN_SIZE) {
-  if (!_apply_preamble({tensor1}))
-    return;
-  if (tensor1.ndimension() < 8) {
-    parallel_for(
-        0,
-        tensor1.numel(),
-        grain_size,
-        [&tensor1, &op](int64_t begin, int64_t end) {
-          apply_op(
-              end - begin,
-              begin,
-              op,
-              strided_tensor_iter_fixed<scalar1, 8>(tensor1, true));
-        });
-  } else {
-    parallel_for(
-        0,
-        tensor1.numel(),
-        grain_size,
-        [&tensor1, &op](int64_t begin, int64_t end) {
-          apply_op(
-              end - begin, begin, op, strided_tensor_iter<scalar1>(tensor1));
-        });
-  }
-}
-
 template <typename scalar1, typename scalar2, typename Op>
 inline void CPU_tensor_parallel_apply2(
     Tensor tensor1,
diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp
index 0ff0b0f..e96fa1c 100644
@@ -105,6 +105,7 @@ void _copy_same_type__cpu(Tensor& self, const Tensor& src) {
     return;
   }
 
+  // TODO: Replace this with TensorIterator!
   bool serial_path = false;
   if (self.numel() == src.numel()) {
     if (self.is_contiguous() && src.is_contiguous()) {
diff --git a/aten/src/ATen/native/Distributions.cpp b/aten/src/ATen/native/Distributions.cpp
index 24392ee..9d7c3aa 100644
@@ -11,7 +11,7 @@
 #include <ATen/core/Generator.h>
 #include <ATen/native/Distributions.h>
 #include <ATen/native/DispatchStub.h>
-#include <ATen/native/cpu/UnaryOpsKernel.h>
+#include <ATen/native/UnaryOps.h>
 
 #include <type_traits>
 #include <functional>
diff --git a/aten/src/ATen/native/TensorIterator.cpp b/aten/src/ATen/native/TensorIterator.cpp
index c588614..b47caca 100644
@@ -479,6 +479,13 @@ std::unique_ptr<TensorIterator> TensorIterator::binary_op(Tensor& out, const Ten
   return builder.build();
 }
 
+std::unique_ptr<TensorIterator> TensorIterator::unary_op(Tensor& out, const Tensor& a) {
+  auto builder = TensorIterator::Builder();
+  builder.add_output(out);
+  builder.add_input(a);
+  return builder.build();
+}
+
 std::unique_ptr<TensorIterator> TensorIterator::reduce_op(Tensor& out, const Tensor& a) {
   AT_ASSERT(out.defined());
   auto builder = TensorIterator::Builder();
diff --git a/aten/src/ATen/native/TensorIterator.h b/aten/src/ATen/native/TensorIterator.h
index c4a1a0b..f3510ab 100644
@@ -126,6 +126,7 @@ struct CAFFE2_API TensorIterator {
   void foreach_reduced_elt(const loop_subiter_t& loop, bool parallelize=true);
 
   static std::unique_ptr<TensorIterator> binary_op(Tensor& out, const Tensor& a, const Tensor& b);
+  static std::unique_ptr<TensorIterator> unary_op(Tensor& out, const Tensor& a);
   static std::unique_ptr<TensorIterator> reduce_op(Tensor& out, const Tensor& a);
 
   int ndim() const { return shape_.size(); }
diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp
index 1dccca4..d690ba6 100644
@@ -9,11 +9,13 @@
 #include <ATen/ExpandUtils.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/LegacyTHFunctions.h>
+#include <ATen/MemoryOverlap.h>
 #include <ATen/WrapDimUtils.h>
 
 #include <ATen/CPUApplyUtils.h>
 #include <ATen/Parallel.h>
-#include <ATen/native/cpu/UnaryOpsKernel.h>
+#include <ATen/native/UnaryOps.h>
+#include <ATen/native/TensorIterator.h>
 
 #include <algorithm>
 #include <cmath>
@@ -113,6 +115,22 @@ Tensor& mvlgamma_(Tensor& self, int64_t p) {
   return self.copy_(args.lgamma_().sum(-1).add_(p * (p - 1) * std::log(M_PI) / 4.));
 }
 
+
+Tensor sigmoid(const Tensor& self) {
+  Tensor result = at::empty({0}, self.options());
+  return at::sigmoid_out(result, self);
+}
+Tensor& _sigmoid__cpu(Tensor& self) {
+  return at::sigmoid_out(self, self);
+}
+Tensor& _sigmoid_out_cpu(Tensor& result, const Tensor& self) {
+  checkBackend("sigmoid", {result}, Backend::CPU);
+  assert_no_internal_overlap(result, "sigmoid");
+  auto iter = TensorIterator::unary_op(result, self);
+  sigmoid_stub(iter->device_type(), *iter);
+  return result;
+}
+
 // NB: If you use this macro, you may also need to add a CUDA forwarding
 // stub in CUDAUnaryOps
 
@@ -121,18 +139,14 @@ Tensor& mvlgamma_(Tensor& self, int64_t p) {
     Tensor result = at::empty({0}, self.options());             \
     return at::op##_out(result, self);                          \
   }                                                             \
-  Tensor& _##op##__cpu(Tensor& self_) {                         \
-    if (self_.numel() > 0) {                                    \
-      Tensor self = sort_strides(self_);                        \
-      op##Impl(kCPU, self, self);                               \
-    }                                                           \
-    return self_;                                               \
+  Tensor& _##op##__cpu(Tensor& self) {                          \
+    return at::op##_out(self, self);                            \
   }                                                             \
   Tensor& _##op##_out_cpu(Tensor& result, const Tensor& self) { \
-    result.resize_(self.sizes());                               \
-    if (result.numel() > 0) {                                   \
-      op##Impl(kCPU, result, self);                             \
-    }                                                           \
+    checkBackend(#op, {result}, Backend::CPU);                  \
+    assert_no_internal_overlap(result, #op);                    \
+    auto iter = TensorIterator::unary_op(result, self);         \
+    op##_stub(iter->device_type(), *iter);                      \
     return result;                                              \
   }
 
@@ -145,8 +159,10 @@ Tensor& mvlgamma_(Tensor& self, int64_t p) {
     return at::op##_out(self, self);                            \
   }                                                             \
   Tensor& _##op##_out_cpu(Tensor& result, const Tensor& self) { \
+    checkBackend(#op, {result}, Backend::CPU);                  \
+    assert_no_internal_overlap(result, #op);                    \
     result.resize_(self.sizes());                               \
-    return at::legacy::th::_th_##op##_out(result, self);                    \
+    return at::legacy::th::_th_##op##_out(result, self);        \
   }
 
 // NB: Temp. defaulting to TH implementation of abs due to issues with Apple
@@ -169,7 +185,6 @@ IMPLEMENT_UNARY_OP_VEC(log1p)
 IMPLEMENT_UNARY_OP_VEC(log2)
 IMPLEMENT_UNARY_OP_VEC(round)
 IMPLEMENT_UNARY_OP_VEC(rsqrt)
-IMPLEMENT_UNARY_OP_VEC(sigmoid)
 IMPLEMENT_UNARY_OP_VEC(sin)
 IMPLEMENT_UNARY_OP_TH(sinh)
 IMPLEMENT_UNARY_OP_VEC(sqrt)
@@ -177,29 +192,29 @@ IMPLEMENT_UNARY_OP_VEC(tan)
 IMPLEMENT_UNARY_OP_VEC(tanh)
 IMPLEMENT_UNARY_OP_VEC(trunc)
 
-DEFINE_DISPATCH(absImpl);
-DEFINE_DISPATCH(acosImpl);
-DEFINE_DISPATCH(asinImpl);
-DEFINE_DISPATCH(atanImpl);
-DEFINE_DISPATCH(ceilImpl);
-DEFINE_DISPATCH(cosImpl);
-DEFINE_DISPATCH(erfImpl);
-DEFINE_DISPATCH(erfcImpl);
-DEFINE_DISPATCH(expImpl);
-DEFINE_DISPATCH(expm1Impl);
-DEFINE_DISPATCH(floorImpl);
-DEFINE_DISPATCH(logImpl);
-DEFINE_DISPATCH(log10Impl);
-DEFINE_DISPATCH(log1pImpl);
-DEFINE_DISPATCH(log2Impl);
-DEFINE_DISPATCH(roundImpl);
-DEFINE_DISPATCH(rsqrtImpl);
-DEFINE_DISPATCH(sigmoidImpl);
-DEFINE_DISPATCH(sinImpl);
-DEFINE_DISPATCH(sqrtImpl);
-DEFINE_DISPATCH(tanImpl);
-DEFINE_DISPATCH(tanhImpl);
-DEFINE_DISPATCH(truncImpl);
+DEFINE_DISPATCH(abs_stub);
+DEFINE_DISPATCH(acos_stub);
+DEFINE_DISPATCH(asin_stub);
+DEFINE_DISPATCH(atan_stub);
+DEFINE_DISPATCH(ceil_stub);
+DEFINE_DISPATCH(cos_stub);
+DEFINE_DISPATCH(erf_stub);
+DEFINE_DISPATCH(erfc_stub);
+DEFINE_DISPATCH(exp_stub);
+DEFINE_DISPATCH(expm1_stub);
+DEFINE_DISPATCH(floor_stub);
+DEFINE_DISPATCH(log_stub);
+DEFINE_DISPATCH(log10_stub);
+DEFINE_DISPATCH(log1p_stub);
+DEFINE_DISPATCH(log2_stub);
+DEFINE_DISPATCH(round_stub);
+DEFINE_DISPATCH(rsqrt_stub);
+DEFINE_DISPATCH(sigmoid_stub);
+DEFINE_DISPATCH(sin_stub);
+DEFINE_DISPATCH(sqrt_stub);
+DEFINE_DISPATCH(tan_stub);
+DEFINE_DISPATCH(tanh_stub);
+DEFINE_DISPATCH(trunc_stub);
 
 }
 } // namespace at
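
For readability, this is roughly what the rewritten IMPLEMENT_UNARY_OP_VEC macro above now expands to for a single op (exp is used purely for illustration; whitespace added):

    // Approximate expansion of IMPLEMENT_UNARY_OP_VEC(exp) after this change.
    Tensor exp(const Tensor& self) {
      Tensor result = at::empty({0}, self.options());
      return at::exp_out(result, self);
    }
    Tensor& _exp__cpu(Tensor& self) {
      return at::exp_out(self, self);
    }
    Tensor& _exp_out_cpu(Tensor& result, const Tensor& self) {
      checkBackend("exp", {result}, Backend::CPU);
      assert_no_internal_overlap(result, "exp");
      auto iter = TensorIterator::unary_op(result, self);
      exp_stub(iter->device_type(), *iter);
      return result;
    }
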
diff --git a/aten/src/ATen/native/UnaryOps.h b/aten/src/ATen/native/UnaryOps.h
new file mode 100644
index 0000000..e60bda8
--- /dev/null
+++ b/aten/src/ATen/native/UnaryOps.h
@@ -0,0 +1,55 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/native/DispatchStub.h>
+#include <ATen/Generator.h>
+#include <stdexcept>
+
+namespace at { struct TensorIterator; }
+
+namespace at { namespace native {
+
+using unary_fn = void(*)(TensorIterator&);
+
+DECLARE_DISPATCH(unary_fn, abs_stub);
+DECLARE_DISPATCH(unary_fn, acos_stub);
+DECLARE_DISPATCH(unary_fn, asin_stub);
+DECLARE_DISPATCH(unary_fn, atan_stub);
+DECLARE_DISPATCH(unary_fn, ceil_stub);
+DECLARE_DISPATCH(unary_fn, cos_stub);
+// DECLARE_DISPATCH(unary_fn, cosh_stub);
+DECLARE_DISPATCH(unary_fn, erf_stub);
+DECLARE_DISPATCH(unary_fn, erfc_stub);
+DECLARE_DISPATCH(unary_fn, exp_stub);
+DECLARE_DISPATCH(unary_fn, expm1_stub);
+DECLARE_DISPATCH(unary_fn, floor_stub);
+DECLARE_DISPATCH(unary_fn, log_stub);
+DECLARE_DISPATCH(unary_fn, log10_stub);
+DECLARE_DISPATCH(unary_fn, log1p_stub);
+DECLARE_DISPATCH(unary_fn, log2_stub);
+DECLARE_DISPATCH(unary_fn, round_stub);
+DECLARE_DISPATCH(unary_fn, rsqrt_stub);
+DECLARE_DISPATCH(unary_fn, sigmoid_stub);
+DECLARE_DISPATCH(unary_fn, sin_stub);
+// DECLARE_DISPATCH(unary_fn, sinh_stub);
+DECLARE_DISPATCH(unary_fn, sqrt_stub);
+DECLARE_DISPATCH(unary_fn, tan_stub);
+DECLARE_DISPATCH(unary_fn, tanh_stub);
+DECLARE_DISPATCH(unary_fn, trunc_stub);
+
+DECLARE_DISPATCH(void(*)(Tensor&, const double, Generator *), bernoulli_mkl_stub);
+
+// Missing unary functions
+// digamma
+// lgamma
+// erfinv
+// fill
+// frac
+// clone
+// contiguous
+// clamp/_min/_max
+// neg
+// reciprocal
+// sign
+// zero
+}} // namespace at::native
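
For orientation, the dispatch machinery touched by this diff has three pieces per op; using sigmoid as the example, each line below appears verbatim elsewhere in this change:

    // 1) Declaration of the stub (aten/src/ATen/native/UnaryOps.h, this new file):
    DECLARE_DISPATCH(unary_fn, sigmoid_stub);
    // 2) Definition of the dispatch table (aten/src/ATen/native/UnaryOps.cpp):
    DEFINE_DISPATCH(sigmoid_stub);
    // 3) Registration of the CPU kernel (aten/src/ATen/native/cpu/UnaryOpsKernel.cpp):
    REGISTER_DISPATCH(sigmoid_stub, &sigmoid_kernel)
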
diff --git a/aten/src/ATen/native/cpu/Loops.h b/aten/src/ATen/native/cpu/Loops.h
index 62b7f56..734c2fe 100644
@@ -33,6 +33,13 @@ static inline bool is_binary_contiguous_s2(const int64_t* strides) {
          strides[2] == 0;
 }
 
+// both operands (output and input) contiguous
+template <typename traits>
+static inline bool is_unary_contiguous(const int64_t* strides) {
+  return strides[0] == sizeof(typename traits::result_type) &&
+         strides[1] == sizeof(typename traits::arg1_t);
+}
+
 // result is
 static inline bool is_reduction(char** data, const int64_t* strides) {
   return strides[0] == 0 &&
@@ -40,6 +47,24 @@ static inline bool is_reduction(char** data, const int64_t* strides) {
          data[0] == data[1];
 }
 
+#define UNARY_LOOP_HEADER(func_t, data, strides) \
+  using traits = unary_function_traits<func_t>; \
+  using arg0_t = typename traits::result_type; \
+  using arg1_t = typename traits::arg1_t; \
+  char* out_ptr = data[0]; \
+  const char* in1_ptr = data[1]; \
+  int64_t s0 = strides[0], s1 = strides[1];
+
+#define UNARY_VEC_HEADER(func_t) \
+  using traits = unary_function_traits<func_t>; \
+  using scalar_t = typename traits::result_type; \
+  using Vec = Vec256<scalar_t>;
+
+#define UNARY_VEC_LOOP_HEADER(func_t, data) \
+  UNARY_VEC_HEADER(func_t) \
+  char* out_ptr = data[0]; \
+  const char* in1_ptr = data[1];
+
 #define LOOP_HEADER(func_t, data, strides) \
   using traits = binary_function_traits<func_t>; \
   using arg0_t = typename traits::result_type; \
@@ -61,6 +86,34 @@ static inline bool is_reduction(char** data, const int64_t* strides) {
   const char* in1_ptr = data[1]; \
   const char* in2_ptr = data[2];
 
+// Basic loop unary operation (one input, one output). May be auto-vectorized
+// by the compiler.
+template <typename func_t>
+static inline void unary_loop(char** data, const int64_t* strides, int64_t i, int64_t n, func_t op) {
+  UNARY_LOOP_HEADER(func_t, data, strides)
+  for (; i < n; i++) {
+    arg1_t in1 = *(arg1_t*)(in1_ptr + i * s1);
+    arg0_t out = op(in1);
+    *(arg0_t*)(out_ptr + i * s0) = out;
+  }
+}
+
+// computes out = op(in1)
+template <typename func_t, typename vec_func_t>
+static inline void vectorized_unary_loop(char** data, int64_t n, func_t op, vec_func_t vop) {
+  UNARY_VEC_LOOP_HEADER(func_t, data)
+  int64_t i = 0;
+  for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) {
+    auto a1 = Vec::loadu(in1_ptr + i * sizeof(scalar_t));
+    auto a2 = Vec::loadu(in1_ptr + (i + Vec::size()) * sizeof(scalar_t));
+    auto out1 = vop(a1);
+    auto out2 = vop(a2);
+    out1.store(out_ptr + i * sizeof(scalar_t));
+    out2.store(out_ptr + (i + Vec::size()) * sizeof(scalar_t));
+  }
+  int64_t strides[] = { sizeof(scalar_t), sizeof(scalar_t) };
+  unary_loop(data, strides, i, n, op);
+}
 
 // Basic loop binary operation (two inputs, one output). May be auto-vectorized
 // by the compiler.
@@ -209,6 +262,36 @@ static inline void vectorized_outer_reduction(char** data, int64_t inner_stride,
 }
 
 template <typename func_t>
+void unary_kernel(TensorIterator& iter, func_t op) {
+  using traits = unary_function_traits<func_t>;
+
+  iter.for_each([&](int ntensor, char** data, const int64_t* strides, int64_t n) {
+    // Specializations to encourage auto-vectorization (trick from Numpy's loops.c.src)
+    if (is_unary_contiguous<traits>(strides)) {
+      unary_loop(data, strides, 0, n, op);
+    } else {
+      unary_loop(data, strides, 0, n, op);
+    }
+  });
+}
+
+template <typename func_t, typename vec_func_t>
+void unary_kernel_vec(TensorIterator& iter, func_t op, vec_func_t vop) {
+  using traits = unary_function_traits<func_t>;
+  static_assert(
+    std::is_same<typename traits::result_type, typename traits::arg1_t>::value,
+    "all types must match");
+
+  iter.for_each([&](int ntensor, char** data, const int64_t* strides, int64_t n) {
+    if (is_unary_contiguous<traits>(strides)) {
+      vectorized_unary_loop(data, n, op, vop);
+    } else {
+      unary_loop(data, strides, 0, n, op);
+    }
+  });
+}
+
+template <typename func_t>
 void binary_kernel(TensorIterator& iter, func_t op) {
   using traits = binary_function_traits<func_t>;
 
diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
index f53733b..c6d309c 100644
@@ -1,5 +1,3 @@
-#include <ATen/native/cpu/UnaryOpsKernel.h>
-
 #include <cmath>
 #include <type_traits>
 #include <ATen/Config.h>
@@ -7,14 +5,18 @@
 #include <ATen/CPUGenerator.h>
 #include <ATen/CheckGenerator.h>
 #include <ATen/Generator.h>
-#include <ATen/MemoryOverlap.h>
+#include <ATen/Parallel.h>
+
 #include <ATen/cpu/vml.h>
-#include <ATen/CPUApplyUtils.h>
-#include <ATen/native/DispatchStub.h>
+#include <ATen/cpu/vec256/vec256.h>
+#include <ATen/cpu/vec256/functional.h>
+
 #include <ATen/native/Distributions.h>
-#ifdef __AVX2__
-#include <ATen/native/cpu/avx_mathfun.h>
-#endif
+#include <ATen/native/TensorIterator.h>
+#include <ATen/native/UnaryOps.h>
+
+#include <ATen/native/cpu/Loops.h>
+
 
 #if AT_MKL_ENABLED()
 #include <mkl.h>
@@ -28,95 +30,21 @@ namespace {
 
 using namespace vec256;
 
-template <typename scalar_t>
-static int64_t _sigmoid(scalar_t* x, scalar_t* y, int64_t size);
-
-// This should be a temporary solution until we understand why SLEEF is slower
-// for sigmoid
-
-template <>
-int64_t _sigmoid(float* x, float* y, int64_t size) {
-  using Vec = Vec256<float>;
-  int64_t i = 0;
-  for (; i < size - (size % (2 * Vec::size())); i += 2 * Vec::size()) {
-    Vec ret = Vec::loadu(y + i);
-    Vec ret2 = Vec::loadu(y + i + Vec::size());
-    ret = ret.neg();
-    ret2 = ret2.neg();
-#if defined(__AVX2__) && !defined(_MSC_VER)
-    ret = exp256_ps(ret);
-    ret2 = exp256_ps(ret2);
-#else
-    ret = ret.exp();
-    ret2 = ret2.exp();
-#endif
-    ret = Vec((float)(1)) + ret;
-    ret2 = Vec((float)(1)) + ret2;
-    ret = ret.reciprocal();
-    ret2 = ret2.reciprocal();
-    ret.store(x + i);
-    ret2.store(x + i + Vec::size());
-  }
-  return i;
-}
-
-template <>
-int64_t _sigmoid(double* x, double* y, int64_t size) {
-  using Vec = Vec256<double>;
-  int64_t i = 0;
-  for (; i < size - (size % (2 * Vec::size())); i += 2 * Vec::size()) {
-    Vec ret = Vec::loadu(y + i);
-    Vec ret2 = Vec::loadu(y + i + Vec::size());
-    ret = ret.neg();
-    ret2 = ret2.neg();
-    ret = ret.exp();
-    ret2 = ret2.exp();
-    ret = Vec((double)(1)) + ret;
-    ret2 = Vec((double)(1)) + ret2;
-    ret = ret.reciprocal();
-    ret2 = ret2.reciprocal();
-    ret.store(x + i);
-    ret2.store(x + i + Vec::size());
-  }
-  return i;
-}
-
-static void sigmoid_kernel(Tensor& result, const Tensor& self) {
-  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "sigmoid", [&] {
-    using Vec = Vec256<scalar_t>;
-    CPU_tensor_parallel_kernel_apply2<scalar_t, scalar_t>(
-        result,
-        self,
-        [](int64_t size,
-           scalar_t* x,
-           scalar_t* y,
-           int64_t stridex,
-           int64_t stridey) {
-          int64_t i = 0;
-          if (stridex == 1 && stridey == 1) {
-            i = _sigmoid(x, y, size);
-          }
-          for (; i < size; i += Vec::size()) {
-            scalar_t buffer[Vec::size()];
-            int64_t width = Vec::size();
-            width = std::min(width, size - i);
-            for (int64_t j = 0; j < width; j++) {
-              buffer[j] = y[stridey * (i + j)];
-            }
-            Vec ret = Vec::loadu(buffer);
-            ret = Vec((scalar_t)(0)) - ret;
-            ret = ret.exp();
-            ret = Vec((scalar_t)(1)) + ret;
-            ret = ret.reciprocal();
-            ret.store(buffer);
-            for (int64_t j = 0; j < width; j++)
-              x[stridex * (i + j)] = buffer[j];
-          }
+static void sigmoid_kernel(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "sigmoid_cpu", [&]() {
+    unary_kernel_vec(
+        iter,
+        [=](scalar_t a) -> scalar_t { return (1 / (1 + std::exp((-a)))); },
+        [=](Vec256<scalar_t> a) {
+          a = Vec256<scalar_t>((scalar_t)(0)) - a;
+          a = a.exp();
+          a = Vec256<scalar_t>((scalar_t)(1)) + a;
+          a = a.reciprocal();
+          return a;
         });
   });
 }
 
-
 #if !AT_MKL_ENABLED()
 void bernoulli_mkl_kernel(Tensor &output, const double p, Generator* gen) {
   // Use AT_ASSERTM because this should never be reached, and AT_ASSERTM tells
@@ -175,47 +103,54 @@ void bernoulli_mkl_kernel(Tensor &self, const double p, Generator* gen) {
 }
 #endif
 
-#define IMPLEMENT_FLOAT_KERNEL(dispatchtypes, op)                          \
-  static void op##_kernel(Tensor& result, const Tensor& self) {            \
-    checkBackend(#op, {result}, Backend::CPU);                             \
-    AT_DISPATCH_##dispatchtypes##_TYPES(self.scalar_type(), #op, [&] {     \
-      if (self.is_contiguous() && result.is_contiguous()) {                \
-        vml::v##op(                                                        \
-            result.data<scalar_t>(), self.data<scalar_t>(), self.numel()); \
-                                                                           \
-      } else {                                                             \
-        assert_no_internal_overlap(result, #op);                           \
-        static constexpr int64_t WIDTH = 131072 / sizeof(scalar_t);        \
-        CPU_tensor_parallel_kernel_apply2<scalar_t, scalar_t>(             \
-            result,                                                        \
-            self,                                                          \
-            [](int64_t size,                                               \
-               scalar_t* x,                                                \
-               scalar_t* y,                                                \
-               int64_t stridex,                                            \
-               int64_t stridey) {                                          \
-              if (stridex == 1 && stridey == 1) {                          \
-                vml::v##op(x, y, size);                                    \
-              } else {                                                     \
-                for (int64_t i = 0; i < size; i += WIDTH) {                \
-                  scalar_t buffer[WIDTH];                                  \
-                  int64_t width = WIDTH;                                   \
-                  width = std::min(width, size - i);                       \
-                  for (int64_t j = 0; j < width; j++)                      \
-                    buffer[j] = y[stridey * (i + j)];                      \
-                  vml::v##op(buffer, buffer, width);                       \
-                  for (int64_t j = 0; j < width; j++)                      \
-                    x[stridex * (i + j)] = buffer[j];                      \
-                }                                                          \
-              }                                                            \
-            });                                                            \
-      }                                                                    \
-    });                                                                    \
-  }                                                                        \
-  REGISTER_DISPATCH(op##Impl, &op##_kernel)
+static void rsqrt_kernel(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "rsqrt_cpu", [&] {
+    unary_kernel_vec(
+        iter,
+        [=](scalar_t a) -> scalar_t {
+          return ((scalar_t)1) / std::sqrt(a);
+        },
+        [=](Vec256<scalar_t> a) { return a.rsqrt(); });
+  });
+}
+
+// TODO: Disable cont. branch to test more risky code
+
+#define IMPLEMENT_FLOAT_KERNEL(dispatchtypes, op)                             \
+  static void op##_kernel(TensorIterator& iter) {                             \
+    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), op##_vml_cpu, [&]() {            \
+      iter.serial_for_each(                                                   \
+          [&](int ntensor, char** data_, const int64_t* strides, int64_t n) { \
+            AT_ASSERT(ntensor == 2);                                          \
+            scalar_t* out_data = reinterpret_cast<scalar_t*>(data_[0]);       \
+            scalar_t* in_data = reinterpret_cast<scalar_t*>(data_[1]);        \
+            int64_t out_stride = strides[0] / sizeof(scalar_t);               \
+            int64_t in_stride = strides[1] / sizeof(scalar_t);                \
+            if (out_stride == 1 && in_stride == 1) {                          \
+              vml::v##op(out_data, in_data, n);                               \
+            } else {                                                          \
+              static constexpr int64_t WIDTH = 131072 / sizeof(scalar_t);     \
+              for (int64_t i = 0; i < n; i += WIDTH) {                        \
+                scalar_t buffer[WIDTH];                                       \
+                int64_t width = WIDTH;                                        \
+                width = std::min(width, n - i);                               \
+                for (int64_t j = 0; j < width; j++)                           \
+                  buffer[j] = in_data[in_stride * (i + j)];                   \
+                vml::v##op(buffer, buffer, width);                            \
+                for (int64_t j = 0; j < width; j++)                           \
+                  out_data[out_stride * (i + j)] = buffer[j];                 \
+              }                                                               \
+            }                                                                 \
+          },                                                                  \
+          {0, iter.numel()});                                                 \
+    });                                                                       \
+  }                                                                           \
+  REGISTER_DISPATCH(op##_stub, &op##_kernel)
+
 } // anonymous namespace
 
-REGISTER_DISPATCH(sigmoidImpl, &sigmoid_kernel)
+REGISTER_DISPATCH(rsqrt_stub, &rsqrt_kernel)
+REGISTER_DISPATCH(sigmoid_stub, &sigmoid_kernel)
 REGISTER_DISPATCH(bernoulli_mkl_stub, &bernoulli_mkl_kernel);
 
 // IMPLEMENT_FLOAT_KERNEL(ALL, abs)
@@ -235,7 +170,6 @@ IMPLEMENT_FLOAT_KERNEL(FLOATING, log10)
 IMPLEMENT_FLOAT_KERNEL(FLOATING, log1p)
 IMPLEMENT_FLOAT_KERNEL(FLOATING, log2)
 IMPLEMENT_FLOAT_KERNEL(FLOATING, round)
-IMPLEMENT_FLOAT_KERNEL(FLOATING, rsqrt)
 IMPLEMENT_FLOAT_KERNEL(FLOATING, sin)
 // IMPLEMENT_FLOAT_KERNEL(FLOATING, sinh)
 IMPLEMENT_FLOAT_KERNEL(FLOATING, sqrt)
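
As a usage sketch (assuming an ATen build containing this change), all three sigmoid entry points on CPU now funnel through the same TensorIterator-based kernel:

    Tensor x = at::randn({8, 1024});
    Tensor y = at::sigmoid(x);       // functional: allocates y, forwards to sigmoid_out
    x.sigmoid_();                    // in-place: _sigmoid__cpu also forwards to sigmoid_out
    Tensor out = at::empty_like(x);
    at::sigmoid_out(out, x);         // out variant: builds the iterator and calls sigmoid_stub
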
diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.h b/aten/src/ATen/native/cpu/UnaryOpsKernel.h
deleted file mode 100644
index e1809d7..0000000
--- a/aten/src/ATen/native/cpu/UnaryOpsKernel.h
+++ /dev/null
@@ -1,59 +0,0 @@
-#pragma once
-
-#include <ATen/ATen.h>
-#include <ATen/native/DispatchStub.h>
-#include <ATen/Generator.h>
-#include <stdexcept>
-
-namespace at { namespace native {
-
-using unary_fn = void(*)(Tensor&, const Tensor&);
-
-DECLARE_DISPATCH(unary_fn, absImpl);
-DECLARE_DISPATCH(unary_fn, acosImpl);
-DECLARE_DISPATCH(unary_fn, asinImpl);
-DECLARE_DISPATCH(unary_fn, atanImpl);
-DECLARE_DISPATCH(unary_fn, ceilImpl);
-DECLARE_DISPATCH(unary_fn, cosImpl);
-// DECLARE_DISPATCH(unary_fn, coshImpl);
-DECLARE_DISPATCH(unary_fn, erfImpl);
-DECLARE_DISPATCH(unary_fn, erfcImpl);
-DECLARE_DISPATCH(unary_fn, expImpl);
-DECLARE_DISPATCH(unary_fn, expm1Impl);
-DECLARE_DISPATCH(unary_fn, floorImpl);
-DECLARE_DISPATCH(unary_fn, logImpl);
-DECLARE_DISPATCH(unary_fn, log10Impl);
-DECLARE_DISPATCH(unary_fn, log1pImpl);
-DECLARE_DISPATCH(unary_fn, log2Impl);
-DECLARE_DISPATCH(unary_fn, roundImpl);
-DECLARE_DISPATCH(unary_fn, rsqrtImpl);
-DECLARE_DISPATCH(unary_fn, sigmoidImpl);
-DECLARE_DISPATCH(unary_fn, sinImpl);
-// DECLARE_DISPATCH(unary_fn, sinhImpl);
-DECLARE_DISPATCH(unary_fn, sqrtImpl);
-DECLARE_DISPATCH(unary_fn, tanImpl);
-DECLARE_DISPATCH(unary_fn, tanhImpl);
-DECLARE_DISPATCH(unary_fn, truncImpl);
-
-DECLARE_DISPATCH(void(*)(Tensor&, const double, Generator *), bernoulli_mkl_stub);
-
-
-// Missing unary functions
-// digamma
-// lgamma
-
-// TODO: See below
-// erfinv
-// fill
-// frac
-// clone
-// contiguous
-// clamp/_min/_max
-// neg
-// reciprocal
-// sigmoid
-// sign
-// zero
-
-
-}} // namespace at::native