From: Christian Puhrsch
Date: Tue, 26 Mar 2019 16:19:51 +0000 (-0700)
Subject: Use TensorIterator for unary operations
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~628
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=cf094d4edcc7928f9a4a368b3e2eeb22579b29b0;p=platform%2Fupstream%2Fpytorch.git

Use TensorIterator for unary operations

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18309

Differential Revision: D14591533

Pulled By: cpuhrsch

fbshipit-source-id: a3b0788a481bddf1803c9f2d3289263d7364f8d7
---

diff --git a/aten/src/ATen/CPUApplyUtils.h b/aten/src/ATen/CPUApplyUtils.h
index 7aadeeb..20753dc 100644
--- a/aten/src/ATen/CPUApplyUtils.h
+++ b/aten/src/ATen/CPUApplyUtils.h
@@ -498,37 +498,6 @@ inline void CPU_tensor_apply4(
   }
 }
 
-template
-inline void CPU_tensor_parallel_apply1(
-    Tensor tensor1,
-    const Op op,
-    int64_t grain_size = internal::GRAIN_SIZE) {
-  if (!_apply_preamble({tensor1}))
-    return;
-  if (tensor1.ndimension() < 8) {
-    parallel_for(
-        0,
-        tensor1.numel(),
-        grain_size,
-        [&tensor1, &op](int64_t begin, int64_t end) {
-          apply_op(
-              end - begin,
-              begin,
-              op,
-              strided_tensor_iter_fixed(tensor1, true));
-        });
-  } else {
-    parallel_for(
-        0,
-        tensor1.numel(),
-        grain_size,
-        [&tensor1, &op](int64_t begin, int64_t end) {
-          apply_op(
-              end - begin, begin, op, strided_tensor_iter(tensor1));
-        });
-  }
-}
-
 template
 inline void CPU_tensor_parallel_apply2(
     Tensor tensor1,
diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp
index 0ff0b0f..e96fa1c 100644
--- a/aten/src/ATen/native/Copy.cpp
+++ b/aten/src/ATen/native/Copy.cpp
@@ -105,6 +105,7 @@ void _copy_same_type__cpu(Tensor& self, const Tensor& src) {
     return;
   }
 
+  // TODO: Replace this with TensorIterator!
   bool serial_path = false;
   if (self.numel() == src.numel()) {
     if (self.is_contiguous() && src.is_contiguous()) {
diff --git a/aten/src/ATen/native/Distributions.cpp b/aten/src/ATen/native/Distributions.cpp
index 24392ee..9d7c3aa 100644
--- a/aten/src/ATen/native/Distributions.cpp
+++ b/aten/src/ATen/native/Distributions.cpp
@@ -11,7 +11,7 @@
 #include
 #include
 #include
-#include
+#include
 #include
 #include
diff --git a/aten/src/ATen/native/TensorIterator.cpp b/aten/src/ATen/native/TensorIterator.cpp
index c588614..b47caca 100644
--- a/aten/src/ATen/native/TensorIterator.cpp
+++ b/aten/src/ATen/native/TensorIterator.cpp
@@ -479,6 +479,13 @@ std::unique_ptr TensorIterator::binary_op(Tensor& out, const Ten
   return builder.build();
 }
 
+std::unique_ptr TensorIterator::unary_op(Tensor& out, const Tensor& a) {
+  auto builder = TensorIterator::Builder();
+  builder.add_output(out);
+  builder.add_input(a);
+  return builder.build();
+}
+
 std::unique_ptr TensorIterator::reduce_op(Tensor& out, const Tensor& a) {
   AT_ASSERT(out.defined());
   auto builder = TensorIterator::Builder();
diff --git a/aten/src/ATen/native/TensorIterator.h b/aten/src/ATen/native/TensorIterator.h
index c4a1a0b..f3510ab 100644
--- a/aten/src/ATen/native/TensorIterator.h
+++ b/aten/src/ATen/native/TensorIterator.h
@@ -126,6 +126,7 @@ struct CAFFE2_API TensorIterator {
   void foreach_reduced_elt(const loop_subiter_t& loop, bool parallelize=true);
 
   static std::unique_ptr binary_op(Tensor& out, const Tensor& a, const Tensor& b);
+  static std::unique_ptr unary_op(Tensor& out, const Tensor& a);
   static std::unique_ptr reduce_op(Tensor& out, const Tensor& a);
 
   int ndim() const { return shape_.size(); }
diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp
index 1dccca4..d690ba6 100644
--- a/aten/src/ATen/native/UnaryOps.cpp
+++ b/aten/src/ATen/native/UnaryOps.cpp
@@ -9,11 +9,13 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
-#include
+#include
+#include
 #include
 #include
 
@@ -113,6 +115,22 @@ Tensor& mvlgamma_(Tensor& self, int64_t p) {
   return self.copy_(args.lgamma_().sum(-1).add_(p * (p - 1) * std::log(M_PI) / 4.));
 }
 
+
+Tensor sigmoid(const Tensor& self) {
+  Tensor result = at::empty({0}, self.options());
+  return at::sigmoid_out(result, self);
+}
+Tensor& _sigmoid__cpu(Tensor& self) {
+  return at::sigmoid_out(self, self);
+}
+Tensor& _sigmoid_out_cpu(Tensor& result, const Tensor& self) {
+  checkBackend("sigmoid", {result}, Backend::CPU);
+  assert_no_internal_overlap(result, "sigmoid");
+  auto iter = TensorIterator::unary_op(result, self);
+  sigmoid_stub(iter->device_type(), *iter);
+  return result;
+}
+
 // NB: If you use this macro, you may also need to add a CUDA forwarding
 // stub in CUDAUnaryOps
 
@@ -121,18 +139,14 @@ Tensor& mvlgamma_(Tensor& self, int64_t p) {
     Tensor result = at::empty({0}, self.options()); \
     return at::op##_out(result, self); \
   } \
-  Tensor& _##op##__cpu(Tensor& self_) { \
-    if (self_.numel() > 0) { \
-      Tensor self = sort_strides(self_); \
-      op##Impl(kCPU, self, self); \
-    } \
-    return self_; \
+  Tensor& _##op##__cpu(Tensor& self) { \
+    return at::op##_out(self, self); \
   } \
   Tensor& _##op##_out_cpu(Tensor& result, const Tensor& self) { \
-    result.resize_(self.sizes()); \
-    if (result.numel() > 0) { \
-      op##Impl(kCPU, result, self); \
-    } \
+    checkBackend(#op, {result}, Backend::CPU); \
+    assert_no_internal_overlap(result, #op); \
+    auto iter = TensorIterator::unary_op(result, self); \
+    op##_stub(iter->device_type(), *iter); \
     return result; \
   }
 
@@ -145,8 +159,10 @@ Tensor& mvlgamma_(Tensor& self, int64_t p) {
     return at::op##_out(self, self); \
   } \
   Tensor& _##op##_out_cpu(Tensor& result, const Tensor& self) { \
+    checkBackend(#op, {result}, Backend::CPU); \
+    assert_no_internal_overlap(result, #op); \
     result.resize_(self.sizes()); \
-    return at::legacy::th::_th_##op##_out(result, self); \
+    return at::legacy::th::_th_##op##_out(result, self); \
   }
 // NB: Temp. defaulting to TH implementation of abs due to issues with Apple
@@ -169,7 +185,6 @@ IMPLEMENT_UNARY_OP_VEC(log1p)
 IMPLEMENT_UNARY_OP_VEC(log2)
 IMPLEMENT_UNARY_OP_VEC(round)
 IMPLEMENT_UNARY_OP_VEC(rsqrt)
-IMPLEMENT_UNARY_OP_VEC(sigmoid)
 IMPLEMENT_UNARY_OP_VEC(sin)
 IMPLEMENT_UNARY_OP_TH(sinh)
 IMPLEMENT_UNARY_OP_VEC(sqrt)
@@ -177,29 +192,29 @@ IMPLEMENT_UNARY_OP_VEC(tan)
 IMPLEMENT_UNARY_OP_VEC(tanh)
 IMPLEMENT_UNARY_OP_VEC(trunc)
 
-DEFINE_DISPATCH(absImpl);
-DEFINE_DISPATCH(acosImpl);
-DEFINE_DISPATCH(asinImpl);
-DEFINE_DISPATCH(atanImpl);
-DEFINE_DISPATCH(ceilImpl);
-DEFINE_DISPATCH(cosImpl);
-DEFINE_DISPATCH(erfImpl);
-DEFINE_DISPATCH(erfcImpl);
-DEFINE_DISPATCH(expImpl);
-DEFINE_DISPATCH(expm1Impl);
-DEFINE_DISPATCH(floorImpl);
-DEFINE_DISPATCH(logImpl);
-DEFINE_DISPATCH(log10Impl);
-DEFINE_DISPATCH(log1pImpl);
-DEFINE_DISPATCH(log2Impl);
-DEFINE_DISPATCH(roundImpl);
-DEFINE_DISPATCH(rsqrtImpl);
-DEFINE_DISPATCH(sigmoidImpl);
-DEFINE_DISPATCH(sinImpl);
-DEFINE_DISPATCH(sqrtImpl);
-DEFINE_DISPATCH(tanImpl);
-DEFINE_DISPATCH(tanhImpl);
-DEFINE_DISPATCH(truncImpl);
+DEFINE_DISPATCH(abs_stub);
+DEFINE_DISPATCH(acos_stub);
+DEFINE_DISPATCH(asin_stub);
+DEFINE_DISPATCH(atan_stub);
+DEFINE_DISPATCH(ceil_stub);
+DEFINE_DISPATCH(cos_stub);
+DEFINE_DISPATCH(erf_stub);
+DEFINE_DISPATCH(erfc_stub);
+DEFINE_DISPATCH(exp_stub);
+DEFINE_DISPATCH(expm1_stub);
+DEFINE_DISPATCH(floor_stub);
+DEFINE_DISPATCH(log_stub);
+DEFINE_DISPATCH(log10_stub);
+DEFINE_DISPATCH(log1p_stub);
+DEFINE_DISPATCH(log2_stub);
+DEFINE_DISPATCH(round_stub);
+DEFINE_DISPATCH(rsqrt_stub);
+DEFINE_DISPATCH(sigmoid_stub);
+DEFINE_DISPATCH(sin_stub);
+DEFINE_DISPATCH(sqrt_stub);
+DEFINE_DISPATCH(tan_stub);
+DEFINE_DISPATCH(tanh_stub);
+DEFINE_DISPATCH(trunc_stub);
 
 }
 } // namespace at
diff --git a/aten/src/ATen/native/UnaryOps.h b/aten/src/ATen/native/UnaryOps.h
new file mode 100644
index 0000000..e60bda8
--- /dev/null
+++ b/aten/src/ATen/native/UnaryOps.h
@@ -0,0 +1,55 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+
+namespace at { struct TensorIterator; }
+
+namespace at { namespace native {
+
+using unary_fn = void(*)(TensorIterator&);
+
+DECLARE_DISPATCH(unary_fn, abs_stub);
+DECLARE_DISPATCH(unary_fn, acos_stub);
+DECLARE_DISPATCH(unary_fn, asin_stub);
+DECLARE_DISPATCH(unary_fn, atan_stub);
+DECLARE_DISPATCH(unary_fn, ceil_stub);
+DECLARE_DISPATCH(unary_fn, cos_stub);
+// DECLARE_DISPATCH(unary_fn, cosh_stub);
+DECLARE_DISPATCH(unary_fn, erf_stub);
+DECLARE_DISPATCH(unary_fn, erfc_stub);
+DECLARE_DISPATCH(unary_fn, exp_stub);
+DECLARE_DISPATCH(unary_fn, expm1_stub);
+DECLARE_DISPATCH(unary_fn, floor_stub);
+DECLARE_DISPATCH(unary_fn, log_stub);
+DECLARE_DISPATCH(unary_fn, log10_stub);
+DECLARE_DISPATCH(unary_fn, log1p_stub);
+DECLARE_DISPATCH(unary_fn, log2_stub);
+DECLARE_DISPATCH(unary_fn, round_stub);
+DECLARE_DISPATCH(unary_fn, rsqrt_stub);
+DECLARE_DISPATCH(unary_fn, sigmoid_stub);
+DECLARE_DISPATCH(unary_fn, sin_stub);
+// DECLARE_DISPATCH(unary_fn, sinh_stub);
+DECLARE_DISPATCH(unary_fn, sqrt_stub);
+DECLARE_DISPATCH(unary_fn, tan_stub);
+DECLARE_DISPATCH(unary_fn, tanh_stub);
+DECLARE_DISPATCH(unary_fn, trunc_stub);
+
+DECLARE_DISPATCH(void(*)(Tensor&, const double, Generator *), bernoulli_mkl_stub);
+
+// Missing unary functions
+// digamma
+// lgamma
+// erfinv
+// fill
+// frac
+// clone
+// contiguous
+// clamp/_min/_max
+// neg
+// reciprocal
+// sign
+// zero
+}} // namespace at::native
diff --git a/aten/src/ATen/native/cpu/Loops.h b/aten/src/ATen/native/cpu/Loops.h
index 62b7f56..734c2fe 100644
--- a/aten/src/ATen/native/cpu/Loops.h
+++ b/aten/src/ATen/native/cpu/Loops.h
@@ -33,6 +33,13 @@ static inline bool is_binary_contiguous_s2(const int64_t* strides) {
          strides[2] == 0;
 }
 
+// both operands (output and input) contiguous
+template
+static inline bool is_unary_contiguous(const int64_t* strides) {
+  return strides[0] == sizeof(typename traits::result_type) &&
+         strides[1] == sizeof(typename traits::arg1_t);
+}
+
 // result is
 static inline bool is_reduction(char** data, const int64_t* strides) {
   return strides[0] == 0 &&
@@ -40,6 +47,24 @@ static inline bool is_reduction(char** data, const int64_t* strides) {
          data[0] == data[1];
 }
 
+#define UNARY_LOOP_HEADER(func_t, data, strides) \
+  using traits = unary_function_traits; \
+  using arg0_t = typename traits::result_type; \
+  using arg1_t = typename traits::arg1_t; \
+  char* out_ptr = data[0]; \
+  const char* in1_ptr = data[1]; \
+  int64_t s0 = strides[0], s1 = strides[1];
+
+#define UNARY_VEC_HEADER(func_t) \
+  using traits = unary_function_traits; \
+  using scalar_t = typename traits::result_type; \
+  using Vec = Vec256;
+
+#define UNARY_VEC_LOOP_HEADER(func_t, data) \
+  UNARY_VEC_HEADER(func_t) \
+  char* out_ptr = data[0]; \
+  const char* in1_ptr = data[1];
+
 #define LOOP_HEADER(func_t, data, strides) \
   using traits = binary_function_traits; \
   using arg0_t = typename traits::result_type; \
@@ -61,6 +86,34 @@ static inline bool is_reduction(char** data, const int64_t* strides) {
   const char* in1_ptr = data[1]; \
   const char* in2_ptr = data[2];
 
+// Basic loop unary operation (one input, one output). May be auto-vectorized
+// by the compiler.
+template
+static inline void unary_loop(char** data, const int64_t* strides, int64_t i, int64_t n, func_t op) {
+  UNARY_LOOP_HEADER(func_t, data, strides)
+  for (; i < n; i++) {
+    arg1_t in1 = *(arg1_t*)(in1_ptr + i * s1);
+    arg0_t out = op(in1);
+    *(arg0_t*)(out_ptr + i * s0) = out;
+  }
+}
+
+// computes out = op(in1)
+template
+static inline void vectorized_unary_loop(char** data, int64_t n, func_t op, vec_func_t vop) {
+  UNARY_VEC_LOOP_HEADER(func_t, data)
+  int64_t i = 0;
+  for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) {
+    auto a1 = Vec::loadu(in1_ptr + i * sizeof(scalar_t));
+    auto a2 = Vec::loadu(in1_ptr + (i + Vec::size()) * sizeof(scalar_t));
+    auto out1 = vop(a1);
+    auto out2 = vop(a2);
+    out1.store(out_ptr + i * sizeof(scalar_t));
+    out2.store(out_ptr + (i + Vec::size()) * sizeof(scalar_t));
+  }
+  int64_t strides[] = { sizeof(scalar_t), sizeof(scalar_t) };
+  unary_loop(data, strides, i, n, op);
+}
 
 // Basic loop binary operation (two inputs, one output). May be auto-vectorized
 // by the compiler.
@@ -209,6 +262,36 @@ static inline void vectorized_outer_reduction(char** data, int64_t inner_stride,
 }
 
 template
+void unary_kernel(TensorIterator& iter, func_t op) {
+  using traits = unary_function_traits;
+
+  iter.for_each([&](int ntensor, char** data, const int64_t* strides, int64_t n) {
+    // Specializations to encourage auto-vectorization (trick from Numpy's loops.c.src)
+    if (is_unary_contiguous(strides)) {
+      unary_loop(data, strides, 0, n, op);
+    } else {
+      unary_loop(data, strides, 0, n, op);
+    }
+  });
+}
+
+template
+void unary_kernel_vec(TensorIterator& iter, func_t op, vec_func_t vop) {
+  using traits = unary_function_traits;
+  static_assert(
+    std::is_same::value,
+    "all types must match");
+
+  iter.for_each([&](int ntensor, char** data, const int64_t* strides, int64_t n) {
+    if (is_unary_contiguous(strides)) {
+      vectorized_unary_loop(data, n, op, vop);
+    } else {
+      unary_loop(data, strides, 0, n, op);
+    }
+  });
+}
+
+template
 void binary_kernel(TensorIterator& iter, func_t op) {
   using traits = binary_function_traits;
diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
index f53733b..c6d309c 100644
--- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -1,5 +1,3 @@
-#include
-
 #include
 #include
 #include
@@ -7,14 +5,18 @@
 #include
 #include
 #include
-#include
+#include
+
 #include
-#include
-#include
+#include
+#include
+
 #include
-#ifdef __AVX2__
-#include
-#endif
+#include
+#include
+
+#include
+
 #if AT_MKL_ENABLED()
 #include
@@ -28,95 +30,21 @@ namespace {
 
 using namespace vec256;
 
-template
-static int64_t _sigmoid(scalar_t* x, scalar_t* y, int64_t size);
-
-// This should be a temporary solution until we understand why SLEEF is slower
-// for sigmoid
-
-template <>
-int64_t _sigmoid(float* x, float* y, int64_t size) {
-  using Vec = Vec256;
-  int64_t i = 0;
-  for (; i < size - (size % (2 * Vec::size())); i += 2 * Vec::size()) {
-    Vec ret = Vec::loadu(y + i);
-    Vec ret2 = Vec::loadu(y + i + Vec::size());
-    ret = ret.neg();
-    ret2 = ret2.neg();
-#if defined(__AVX2__) && !defined(_MSC_VER)
-    ret = exp256_ps(ret);
-    ret2 = exp256_ps(ret2);
-#else
-    ret = ret.exp();
-    ret2 = ret2.exp();
-#endif
-    ret = Vec((float)(1)) + ret;
-    ret2 = Vec((float)(1)) + ret2;
-    ret = ret.reciprocal();
-    ret2 = ret2.reciprocal();
-    ret.store(x + i);
-    ret2.store(x + i + Vec::size());
-  }
-  return i;
-}
-
-template <>
-int64_t _sigmoid(double* x, double* y, int64_t size) {
-  using Vec = Vec256;
-  int64_t i = 0;
-  for (; i < size - (size % (2 * Vec::size())); i += 2 * Vec::size()) {
-    Vec ret = Vec::loadu(y + i);
-    Vec ret2 = Vec::loadu(y + i + Vec::size());
-    ret = ret.neg();
-    ret2 = ret2.neg();
-    ret = ret.exp();
-    ret2 = ret2.exp();
-    ret = Vec((double)(1)) + ret;
-    ret2 = Vec((double)(1)) + ret2;
-    ret = ret.reciprocal();
-    ret2 = ret2.reciprocal();
-    ret.store(x + i);
-    ret2.store(x + i + Vec::size());
-  }
-  return i;
-}
-
-static void sigmoid_kernel(Tensor& result, const Tensor& self) {
-  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "sigmoid", [&] {
-    using Vec = Vec256;
-    CPU_tensor_parallel_kernel_apply2(
-        result,
-        self,
-        [](int64_t size,
-           scalar_t* x,
-           scalar_t* y,
-           int64_t stridex,
-           int64_t stridey) {
-          int64_t i = 0;
-          if (stridex == 1 && stridey == 1) {
-            i = _sigmoid(x, y, size);
-          }
-          for (; i < size; i += Vec::size()) {
-            scalar_t buffer[Vec::size()];
-            int64_t width = Vec::size();
-            width = std::min(width, size - i);
-            for (int64_t j = 0; j < width; j++) {
-              buffer[j] = y[stridey * (i + j)];
-            }
-            Vec ret = Vec::loadu(buffer);
-            ret = Vec((scalar_t)(0)) - ret;
-            ret = ret.exp();
-            ret = Vec((scalar_t)(1)) + ret;
-            ret = ret.reciprocal();
-            ret.store(buffer);
-            for (int64_t j = 0; j < width; j++)
-              x[stridex * (i + j)] = buffer[j];
-          }
+static void sigmoid_kernel(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "sigmoid_cpu", [&]() {
+    unary_kernel_vec(
+        iter,
+        [=](scalar_t a) -> scalar_t { return (1 / (1 + std::exp((-a)))); },
+        [=](Vec256 a) {
+          a = Vec256((scalar_t)(0)) - a;
+          a = a.exp();
+          a = Vec256((scalar_t)(1)) + a;
+          a = a.reciprocal();
+          return a;
         });
   });
 }
-
 #if !AT_MKL_ENABLED()
 void bernoulli_mkl_kernel(Tensor &output, const double p, Generator* gen) {
   // Use AT_ASSERTM because this should never be reached, and AT_ASSERTM tells
@@ -175,47 +103,54 @@ void bernoulli_mkl_kernel(Tensor &self, const double p, Generator* gen) {
 }
 #endif
 
-#define IMPLEMENT_FLOAT_KERNEL(dispatchtypes, op) \
-  static void op##_kernel(Tensor& result, const Tensor& self) { \
-    checkBackend(#op, {result}, Backend::CPU); \
-    AT_DISPATCH_##dispatchtypes##_TYPES(self.scalar_type(), #op, [&] { \
-      if (self.is_contiguous() && result.is_contiguous()) { \
-        vml::v##op( \
-            result.data(), self.data(), self.numel()); \
- \
-      } else { \
-        assert_no_internal_overlap(result, #op); \
-        static constexpr int64_t WIDTH = 131072 / sizeof(scalar_t); \
-        CPU_tensor_parallel_kernel_apply2( \
-            result, \
-            self, \
-            [](int64_t size, \
-               scalar_t* x, \
-               scalar_t* y, \
-               int64_t stridex, \
-               int64_t stridey) { \
-              if (stridex == 1 && stridey == 1) { \
-                vml::v##op(x, y, size); \
-              } else { \
-                for (int64_t i = 0; i < size; i += WIDTH) { \
-                  scalar_t buffer[WIDTH]; \
-                  int64_t width = WIDTH; \
-                  width = std::min(width, size - i); \
-                  for (int64_t j = 0; j < width; j++) \
-                    buffer[j] = y[stridey * (i + j)]; \
-                  vml::v##op(buffer, buffer, width); \
-                  for (int64_t j = 0; j < width; j++) \
-                    x[stridex * (i + j)] = buffer[j]; \
-                } \
-              } \
-            }); \
-      } \
-    }); \
-  } \
-  REGISTER_DISPATCH(op##Impl, &op##_kernel)
+static void rsqrt_kernel(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "rsqrt_cpu", [&] {
+    unary_kernel_vec(
+        iter,
+        [=](scalar_t a) -> scalar_t {
+          return ((scalar_t)1) / std::sqrt(a);
+        },
+        [=](Vec256 a) { return a.rsqrt(); });
+  });
+}
+
+// TODO: Disable cont. branch to test more risky code
+#define IMPLEMENT_FLOAT_KERNEL(dispatchtypes, op) \
+  static void op##_kernel(TensorIterator& iter) { \
+    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), op##_vml_cpu, [&]() { \
+      iter.serial_for_each( \
+          [&](int ntensor, char** data_, const int64_t* strides, int64_t n) { \
+            AT_ASSERT(ntensor == 2); \
+            scalar_t* out_data = reinterpret_cast(data_[0]); \
+            scalar_t* in_data = reinterpret_cast(data_[1]); \
+            int64_t out_stride = strides[0] / sizeof(scalar_t); \
+            int64_t in_stride = strides[1] / sizeof(scalar_t); \
+            if (out_stride == 1 && in_stride == 1) { \
+              vml::v##op(out_data, in_data, n); \
+            } else { \
+              static constexpr int64_t WIDTH = 131072 / sizeof(scalar_t); \
+              for (int64_t i = 0; i < n; i += WIDTH) { \
+                scalar_t buffer[WIDTH]; \
+                int64_t width = WIDTH; \
+                width = std::min(width, n - i); \
+                for (int64_t j = 0; j < width; j++) \
+                  buffer[j] = in_data[in_stride * (i + j)]; \
+                vml::v##op(buffer, buffer, width); \
+                for (int64_t j = 0; j < width; j++) \
+                  out_data[out_stride * (i + j)] = buffer[j]; \
+              } \
+            } \
+          }, \
+          {0, iter.numel()}); \
+    }); \
+  } \
+  REGISTER_DISPATCH(op##_stub, &op##_kernel)
+
 } // anonymous namespace
 
-REGISTER_DISPATCH(sigmoidImpl, &sigmoid_kernel)
+REGISTER_DISPATCH(rsqrt_stub, &rsqrt_kernel)
+REGISTER_DISPATCH(sigmoid_stub, &sigmoid_kernel)
 REGISTER_DISPATCH(bernoulli_mkl_stub, &bernoulli_mkl_kernel);
 
 // IMPLEMENT_FLOAT_KERNEL(ALL, abs)
@@ -235,7 +170,6 @@ IMPLEMENT_FLOAT_KERNEL(FLOATING, log10)
 IMPLEMENT_FLOAT_KERNEL(FLOATING, log1p)
 IMPLEMENT_FLOAT_KERNEL(FLOATING, log2)
 IMPLEMENT_FLOAT_KERNEL(FLOATING, round)
-IMPLEMENT_FLOAT_KERNEL(FLOATING, rsqrt)
 IMPLEMENT_FLOAT_KERNEL(FLOATING, sin)
 // IMPLEMENT_FLOAT_KERNEL(FLOATING, sinh)
 IMPLEMENT_FLOAT_KERNEL(FLOATING, sqrt)
diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.h b/aten/src/ATen/native/cpu/UnaryOpsKernel.h
deleted file mode 100644
index e1809d7..0000000
--- a/aten/src/ATen/native/cpu/UnaryOpsKernel.h
+++ /dev/null
@@ -1,59 +0,0 @@
-#pragma once
-
-#include
-#include
-#include
-#include
-
-namespace at { namespace native {
-
-using unary_fn = void(*)(Tensor&, const Tensor&);
-
-DECLARE_DISPATCH(unary_fn, absImpl);
-DECLARE_DISPATCH(unary_fn, acosImpl);
-DECLARE_DISPATCH(unary_fn, asinImpl);
-DECLARE_DISPATCH(unary_fn, atanImpl);
-DECLARE_DISPATCH(unary_fn, ceilImpl);
-DECLARE_DISPATCH(unary_fn, cosImpl);
-// DECLARE_DISPATCH(unary_fn, coshImpl);
-DECLARE_DISPATCH(unary_fn, erfImpl);
-DECLARE_DISPATCH(unary_fn, erfcImpl);
-DECLARE_DISPATCH(unary_fn, expImpl);
-DECLARE_DISPATCH(unary_fn, expm1Impl);
-DECLARE_DISPATCH(unary_fn, floorImpl);
-DECLARE_DISPATCH(unary_fn, logImpl);
-DECLARE_DISPATCH(unary_fn, log10Impl);
-DECLARE_DISPATCH(unary_fn, log1pImpl);
-DECLARE_DISPATCH(unary_fn, log2Impl);
-DECLARE_DISPATCH(unary_fn, roundImpl);
-DECLARE_DISPATCH(unary_fn, rsqrtImpl);
-DECLARE_DISPATCH(unary_fn, sigmoidImpl);
-DECLARE_DISPATCH(unary_fn, sinImpl);
-// DECLARE_DISPATCH(unary_fn, sinhImpl);
-DECLARE_DISPATCH(unary_fn, sqrtImpl);
-DECLARE_DISPATCH(unary_fn, tanImpl);
-DECLARE_DISPATCH(unary_fn, tanhImpl);
-DECLARE_DISPATCH(unary_fn, truncImpl);
-
-DECLARE_DISPATCH(void(*)(Tensor&, const double, Generator *), bernoulli_mkl_stub);
-
-
-// Missing unary functions
-// digamma
-// lgamma
-
-// TODO: See below
-// erfinv
-// fill
-// frac
-// clone
-// contiguous
-// clamp/_min/_max
-// neg
-// reciprocal
-// sigmoid
-// sign
-// zero
-
-
-}} // namespace at::native
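
The patch above moves CPU unary ops onto TensorIterator: the out-variant builds an iterator with TensorIterator::unary_op and calls a per-op dispatch stub, and each CPU stub is written against unary_kernel / unary_kernel_vec, which walk the output and input as raw byte pointers with per-operand strides. The snippet below is a minimal, standalone sketch of that strided unary-loop pattern, not ATen code: TensorIterator, the dispatch stubs, and Vec256 are omitted, and the names run_unary and the sigmoid lambda are stand-ins chosen only for this illustration.

// Simplified stand-in for the patch's unary_loop: data[0] is the output,
// data[1] the input, both addressed as raw bytes with per-operand strides,
// so contiguous and strided buffers share one code path.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

template <typename scalar_t, typename func_t>
void run_unary(char** data, const int64_t* strides, int64_t n, func_t op) {
  char* out_ptr = data[0];
  const char* in_ptr = data[1];
  int64_t s0 = strides[0], s1 = strides[1];
  for (int64_t i = 0; i < n; i++) {
    scalar_t in = *(const scalar_t*)(in_ptr + i * s1);
    *(scalar_t*)(out_ptr + i * s0) = op(in);
  }
}

int main() {
  std::vector<float> in = {-2.f, -1.f, 0.f, 1.f, 2.f};
  std::vector<float> out(in.size());
  char* data[2] = {(char*)out.data(), (char*)in.data()};
  // Contiguous case: both strides equal sizeof(float), which is what the
  // patch's is_unary_contiguous() checks before taking the vectorized path.
  int64_t strides[2] = {sizeof(float), sizeof(float)};
  run_unary<float>(data, strides, (int64_t)in.size(),
                   [](float a) { return 1.f / (1.f + std::exp(-a)); });
  for (float v : out) std::printf("%f\n", v);
  return 0;
}

The byte-pointer/stride form is what lets one loop body serve both contiguous and arbitrarily strided tensors; the patch only adds a separate fast path (vectorized_unary_loop) when is_unary_contiguous detects unit element strides.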