From: Iurii Zdebskyi
Date: Thu, 7 Mar 2019 21:38:59 +0000 (-0800)
Subject: Refactor dispatcher (#17753)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~937
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=6227afb305e01d73633742383095f42dffd29f63;p=platform%2Fupstream%2Fpytorch.git

Refactor dispatcher (#17753)

Summary:
This is a side PR for the bool tensor feature. The idea for this change came from feedback received in this [PR](https://github.com/pytorch/pytorch/pull/17376).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17753

Differential Revision: D14367989

Pulled By: izdeby

fbshipit-source-id: 4fa380e56e20f18e480be68920170dbc3a4eb91c
---
diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h
index 7c8bff3..20d7f2d 100644
--- a/aten/src/ATen/Dispatch.h
+++ b/aten/src/ATen/Dispatch.h
@@ -81,73 +81,48 @@
     } \
   }()
 
-#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \
-  [&] { \
-    const at::Type& the_type = TYPE; \
-    switch (the_type.scalarType()) { \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
-    default: \
-      AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
-    } \
-  }()
+template <at::ScalarType N>
+struct MyTemplate;
 
-#define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
-  [&] { \
-    const at::Type& the_type = TYPE; \
-    switch (the_type.scalarType()) { \
-      AT_PRIVATE_CASE_TYPE( \
-          at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE( \
-          at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__) \
-    default: \
-      AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
-    } \
-  }()
+template<>
+struct MyTemplate<at::ScalarType::Half> {
+  using type = at::Half;
+};
 
-#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \
-  [&] { \
-    const at::Type& the_type = TYPE; \
-    switch (the_type.scalarType()) { \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE( \
-          at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE( \
-          at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__) \
-    default: \
-      AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
-    } \
+#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
+  [&] { \
+    const at::Type& the_type = TYPE; \
+    switch (the_type.scalarType()) { \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE(SCALARTYPE, MyTemplate<SCALARTYPE>::type, __VA_ARGS__) \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
+    } \
   }()
 
-#define AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX(TYPE, NAME, ...) \
-  [&] { \
-    const at::Type& the_type = TYPE; \
-    switch (the_type.scalarType()) { \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE( \
-          at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE( \
-          at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__) \
-    default: \
-      AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
-    } \
+#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \
+  [&] { \
+    const at::Type& the_type = TYPE; \
+    switch (the_type.scalarType()) { \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE(SCALARTYPE, MyTemplate<SCALARTYPE>::type, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE( \
+          at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE( \
+          at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__) \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
+    } \
   }()
diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp
index 871c274..a8c7bab 100644
--- a/aten/src/ATen/native/Copy.cpp
+++ b/aten/src/ATen/native/Copy.cpp
@@ -20,7 +20,7 @@ void _copy__cpu(at::Tensor& self, const at::Tensor& src) {
 template
 void _copy__cpu(at::Tensor& self, const at::Tensor& src) {
   AT_CHECK(self.numel() == src.numel(), "sizes do not match");
-  AT_DISPATCH_ALL_TYPES_AND_HALF(src.type(), "_copy__cpu", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "_copy__cpu", [&]() {
     _copy__cpu(self, src);
   });
 }
@@ -42,8 +42,8 @@ Tensor& _s_copy__cpu(Tensor& self, const Tensor& src, bool non_blocking) {
     _s_copy_from(src, self, non_blocking);
     return self;
   }
 
-  AT_DISPATCH_ALL_TYPES_AND_HALF(
-      self.type(), "_copy__cpu", [&]() { ::_copy__cpu(self, src); });
+  AT_DISPATCH_ALL_TYPES_AND(
+      at::ScalarType::Half, self.type(), "_copy__cpu", [&]() { ::_copy__cpu(self, src); });
   return self;
 }
@@ -58,8 +58,8 @@ void _copy_same_type_transpose_(Tensor& self, const Tensor& src) {
   }
   Tensor buf = empty({BLOCK_SZ, BLOCK_SZ}, self.options());
 
-  AT_DISPATCH_ALL_TYPES_AND_HALF(
-      self.type(), "_copy_same_type_transpose_", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(
+      at::ScalarType::Half, self.type(), "_copy_same_type_transpose_", [&]() {
        scalar_t* sp = src.data();
        scalar_t* rp = self.data();
        scalar_t* bp = buf.data();
@@ -114,12 +114,13 @@ void _copy_same_type__cpu(Tensor& self, const Tensor& src) {
   } else {
 #ifdef _OPENMP
     if (!in_parallel_region()) {
-      AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "_copy_same_type_", [&]() {
-        at::CPU_tensor_parallel_apply2(
-            self, src, [](scalar_t& self_val, const scalar_t& src_val) {
-              self_val = src_val;
-            });
-      });
+      AT_DISPATCH_ALL_TYPES_AND(
+          at::ScalarType::Half, self.type(), "_copy_same_type_", [&]() {
+            at::CPU_tensor_parallel_apply2(
+                self, src, [](scalar_t& self_val, const scalar_t& src_val) {
+                  self_val = src_val;
+                });
+          });
     } else {
       serial_path = true;
     }
@@ -132,12 +133,13 @@ void _copy_same_type__cpu(Tensor& self, const Tensor& src) {
   }
 
   if (serial_path) {
-    AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "_copy_same_type_", [&]() {
-      at::CPU_tensor_apply2(
-          self, src, [](scalar_t& self_val, const scalar_t& src_val) {
-            self_val = src_val;
-          });
-    });
+    AT_DISPATCH_ALL_TYPES_AND(
+        at::ScalarType::Half, self.type(), "_copy_same_type_", [&]() {
+          at::CPU_tensor_apply2(
+              self, src, [](scalar_t& self_val, const scalar_t& src_val) {
+                self_val = src_val;
+              });
+        });
   }
 }
diff --git a/aten/src/ATen/native/Scalar.cpp b/aten/src/ATen/native/Scalar.cpp
index 4de16d7..c635c9d 100644
--- a/aten/src/ATen/native/Scalar.cpp
+++ b/aten/src/ATen/native/Scalar.cpp
@@ -18,8 +18,8 @@ Scalar item(const Tensor& self) {
 
 Scalar _local_scalar_dense_cpu(const Tensor& self) {
   Scalar r;
-  AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX(
-      self.type(), "_local_scalar_dense_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(
+      at::ScalarType::Half, self.type(), "_local_scalar_dense_cpu", [&] {
        scalar_t value = *self.data();
        r = Scalar(value);
      });
diff --git a/aten/src/ATen/native/cpu/CopyKernel.cpp b/aten/src/ATen/native/cpu/CopyKernel.cpp
index caf0364..03ef8df 100644
--- a/aten/src/ATen/native/cpu/CopyKernel.cpp
+++ b/aten/src/ATen/native/cpu/CopyKernel.cpp
@@ -14,15 +14,16 @@ namespace {
 constexpr int64_t COPY_GRAIN_SIZE = 20000;
 
 static void copy_kernel_impl(Tensor& dst, const Tensor& src) {
-  AT_DISPATCH_ALL_TYPES_AND_HALF(dst.type(), "copy_kernel_impl", [&]() {
-    scalar_t* self_ptr = dst.data();
-    scalar_t* src_ptr = src.data();
-
-    auto sample = [&](int64_t begin, int64_t end) {
-      int64_t len = end - begin;
-      scalar_t* self_seg = self_ptr + begin;
-      scalar_t* src_seg = src_ptr + begin;
-      at::vec256::convert(src_seg, self_seg, len);
+  AT_DISPATCH_ALL_TYPES_AND(
+      at::ScalarType::Half, dst.type(), "copy_kernel_impl", [&]() {
+        scalar_t* self_ptr = dst.data();
+        scalar_t* src_ptr = src.data();
+
+        auto sample = [&](int64_t begin, int64_t end) {
+          int64_t len = end - begin;
+          scalar_t* self_seg = self_ptr + begin;
+          scalar_t* src_seg = src_ptr + begin;
+          at::vec256::convert(src_seg, self_seg, len);
    };
 
    parallel_for(0, dst.numel(), COPY_GRAIN_SIZE, sample);
diff --git a/aten/src/ATen/native/cpu/IndexKernel.cpp b/aten/src/ATen/native/cpu/IndexKernel.cpp
index c89621d..0d93112 100644
--- a/aten/src/ATen/native/cpu/IndexKernel.cpp
+++ b/aten/src/ATen/native/cpu/IndexKernel.cpp
@@ -92,7 +92,7 @@ void cpu_index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
 }
 
 void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
-  AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(0), "index", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(0), "index", [&] {
     cpu_index_kernel(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
       *(scalar_t*)dst = *(scalar_t*)(src + offset);
     });
@@ -101,7 +101,7 @@ void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef inde
 
 void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
   // NOTE: duplicate indices are only supported if accumulate is true.
-  AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(0), "index_put", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(0), "index_put", [&] {
     if (accumulate) {
       // TODO: investigate parallelization of the accumulate kernel. Unlike the non-accumulate case,
       // this needs to be thread-safe.
diff --git a/aten/src/ATen/native/cuda/Activation.cu b/aten/src/ATen/native/cuda/Activation.cu
index b4b7894..83e879a 100644
--- a/aten/src/ATen/native/cuda/Activation.cu
+++ b/aten/src/ATen/native/cuda/Activation.cu
@@ -286,7 +286,7 @@ void threshold_kernel_impl(TensorIterator& iter, scalar_t threshold, scalar_t va
 }
 
 static void threshold_kernel(TensorIterator& iter, Scalar threshold, Scalar value) {
-  AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(), "threshold", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "threshold", [&] {
     threshold_kernel_impl(iter, threshold.to(), value.to());
   });
 }
diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
index 2ac7f5f..1f0184b 100644
--- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
+++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -493,7 +493,7 @@ Tensor& triu_tril_cuda_template(Tensor& result, const Tensor& self, int64_t k, c
     self_row_stride = self.stride(-2), self_col_stride = self.stride(-1);
   dim3 dim_block = cuda::getApplyBlock();
   dim3 dim_grid((mat_size + dim_block.x - 1) / dim_block.x, n_batches);
-  AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), name, [&]{
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), name, [&]{
     triu_tril_kernel
      <<>>(
        result.data(), self.data(), k, mat_size,
diff --git a/aten/src/ATen/native/cuda/BinaryOpsKernel.cu b/aten/src/ATen/native/cuda/BinaryOpsKernel.cu
index fa31a2d..5cb1212 100644
--- a/aten/src/ATen/native/cuda/BinaryOpsKernel.cu
+++ b/aten/src/ATen/native/cuda/BinaryOpsKernel.cu
@@ -22,7 +22,7 @@ void add_kernel_impl(TensorIterator& iter, Scalar alpha_scalar) {
 }
 
 static void add_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) {
-  AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(), "add", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "add", [&]() {
     add_kernel_impl(iter, alpha_scalar);
   });
 }
@@ -74,7 +74,7 @@ void mul_kernel_impl(TensorIterator& iter) {
 }
 
 static void mul_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(), "mul", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "mul", [&]() {
     mul_kernel_impl(iter);
   });
 }
diff --git a/aten/src/ATen/native/cuda/CUDAScalar.cu b/aten/src/ATen/native/cuda/CUDAScalar.cu
index d4dc8d4..70b9236 100644
--- a/aten/src/ATen/native/cuda/CUDAScalar.cu
+++ b/aten/src/ATen/native/cuda/CUDAScalar.cu
@@ -9,8 +9,8 @@ namespace native {
 
 Scalar _local_scalar_dense_cuda(const Tensor& self) {
   Scalar r;
-  AT_DISPATCH_ALL_TYPES_AND_HALF(
-      self.type(), "_local_scalar_dense_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(
+      at::ScalarType::Half, self.type(), "_local_scalar_dense_cuda", [&] {
        scalar_t value;
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        AT_CUDA_CHECK(cudaMemcpyAsync(&value, self.data(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream));
diff --git a/aten/src/ATen/native/cuda/Copy.cu b/aten/src/ATen/native/cuda/Copy.cu
index cb96963..ab85b48 100644
--- a/aten/src/ATen/native/cuda/Copy.cu
+++ b/aten/src/ATen/native/cuda/Copy.cu
@@ -169,7 +169,7 @@ void copy_from_cpu(Tensor& dst, const Tensor& src) {
       cudaMemcpyHostToDevice,
       stream));
   AT_CUDA_CHECK(cudaStreamSynchronize(stream));
-  AT_DISPATCH_ALL_TYPES_AND_HALF(src.type(), "copy_from_cpu", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "copy_from_cpu", [&]() {
     copy_device_to_device(dst, dst_contig);
   });
 }
@@ -202,7 +202,7 @@ void copy_from_cpu_async_(Tensor& dst, const Tensor& src) {
 
   CUDAGuard device_guard(dst.device());
   CUDAStream stream = getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES_AND_HALF(src.type(), "copy_from_cpu_async", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "copy_from_cpu_async", [&]() {
     AT_CUDA_CHECK(cudaMemcpyAsync(
         dst.data(),
         src.data(),
@@ -225,7 +225,7 @@ void copy_to_cpu_async_(Tensor& dst, const Tensor& src) {
 
   CUDAGuard device_guard(src.device());
   CUDAStream stream = getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES_AND_HALF(src.type(), "copy_to_cpu_async", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "copy_to_cpu_async", [&]() {
     AT_CUDA_CHECK(cudaMemcpyAsync(
         dst.data(),
         src.data(),
@@ -240,7 +240,7 @@ void copy_to_cpu_async_(Tensor& dst, const Tensor& src) {
 template
 void _copy__cuda(Tensor& dst, const Tensor& src, bool non_blocking) {
   AT_CHECK(dst.numel() == src.numel(), "sizes do not match");
-  AT_DISPATCH_ALL_TYPES_AND_HALF(src.type(), "_copy__cuda", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "_copy__cuda", [&]() {
     if (dst.is_cuda() && src.is_cuda()) {
       copy_device_to_device(dst, src);
     } else if (dst.is_cuda()) {
@@ -279,7 +279,7 @@ namespace at {
 namespace native {
 
 Tensor& _s_copy__cuda(Tensor& self, const Tensor& src, bool non_blocking) {
-  AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "_copy__cuda", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), "_copy__cuda", [&]() {
     ::_copy__cuda(self, src, non_blocking);
   });
   return self;
diff --git a/aten/src/ATen/native/cuda/Distributions.cu b/aten/src/ATen/native/cuda/Distributions.cu
index 00a1b34..c262384 100644
--- a/aten/src/ATen/native/cuda/Distributions.cu
+++ b/aten/src/ATen/native/cuda/Distributions.cu
@@ -210,21 +210,22 @@ Tensor _standard_gamma_grad_cuda(const Tensor& self, const Tensor& output) {
 
 Tensor& bernoulli_tensor_cuda_(Tensor &self, const Tensor& p_, Generator* gen) {
   auto p = std::get<0>(expand_inplace(self, p_.to(kCUDA)));
-  AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "bernoulli_tensor_cuda_self_", [&] {
-    const at::Type& p_type = p.type();
-    using self_t = scalar_t;
-    auto seeds = next_philox_seed(gen, 10);
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(p.type(), "bernoulli_tensor_cuda_p_", [&] {
-      using p_t = scalar_t;
-      return bernoulli_tensor_cuda_kernel(self, p, seeds);
-    });
+  AT_DISPATCH_ALL_TYPES_AND(
+      at::ScalarType::Half, self.type(), "bernoulli_tensor_cuda_self_", [&] {
+        const at::Type& p_type = p.type();
+        using self_t = scalar_t;
+        auto seeds = next_philox_seed(gen, 10);
+        AT_DISPATCH_FLOATING_TYPES_AND_HALF(p.type(), "bernoulli_tensor_cuda_p_", [&] {
+          using p_t = scalar_t;
+          return bernoulli_tensor_cuda_kernel(self, p, seeds);
+        });
   });
   return self;
 }
 
 Tensor& bernoulli_scalar_cuda_(Tensor &self, double p, Generator* gen) {
   AT_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p);
-  AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "bernoulli_scalar_cuda_", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), "bernoulli_scalar_cuda_", [&] {
     auto seeds = next_philox_seed(gen, 10);
     bernoulli_scalar_cuda_kernel(self, p, seeds);
   });
diff --git a/aten/src/ATen/native/cuda/IndexKernel.cu b/aten/src/ATen/native/cuda/IndexKernel.cu
index 6abd9f0..be69cde 100644
--- a/aten/src/ATen/native/cuda/IndexKernel.cu
+++ b/aten/src/ATen/native/cuda/IndexKernel.cu
@@ -81,7 +81,7 @@ void index_put_kernel_impl(TensorIterator& iter, IntArrayRef index_size, IntArra
 }
 
 static void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
"index", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "index", [&] { using dtype = OpaqueType; index_kernel_impl(iter, index_size, index_stride); }); @@ -90,7 +90,7 @@ static void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayR static void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) { AT_ASSERTM(!accumulate, "index_put does not support accumulate=true"); - AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(), "index_put", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "index_put", [&] { using dtype = OpaqueType; index_put_kernel_impl(iter, index_size, index_stride); }); diff --git a/aten/src/ATen/native/cuda/RangeFactories.cu b/aten/src/ATen/native/cuda/RangeFactories.cu index 6210b47..227fd81 100644 --- a/aten/src/ATen/native/cuda/RangeFactories.cu +++ b/aten/src/ATen/native/cuda/RangeFactories.cu @@ -99,7 +99,7 @@ Tensor& logspace_cuda_out(Tensor& result, Scalar start, Scalar end, int64_t step } Tensor& range_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) { - AT_DISPATCH_ALL_TYPES_AND_HALF(result.type(), "range", [&]() { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, result.type(), "range", [&]() { using accscalar_t = at::acc_type; auto xstart = start.to(); auto xend = end.to(); @@ -130,7 +130,7 @@ Tensor& range_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) { } Tensor& arange_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) { - AT_DISPATCH_ALL_TYPES_AND_HALF(result.type(), "arange", [&]() { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, result.type(), "arange", [&]() { using accscalar_t = at::acc_type; auto xstart = start.to(); auto xend = end.to(); diff --git a/aten/src/ATen/native/cuda/SortingKthValue.cu b/aten/src/ATen/native/cuda/SortingKthValue.cu index d3c488e..4018d1d 100644 --- a/aten/src/ATen/native/cuda/SortingKthValue.cu +++ b/aten/src/ATen/native/cuda/SortingKthValue.cu @@ -233,14 +233,14 @@ std::tuple kthvalue_out_cuda( int64_t k, int64_t dim, bool keepdim) { - AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "kthvalue", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), "kthvalue", [&] { kthvalue_cuda_template(values, indices, self, k, dim, keepdim); }); return std::forward_as_tuple(values, indices); } Tensor median_cuda(const Tensor& self) { - return AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "median", [&] { + return AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), "median", [&] { return median_cuda_template(self); }); } diff --git a/aten/src/ATen/native/cuda/TensorCompare.cu b/aten/src/ATen/native/cuda/TensorCompare.cu index f7612d5..f7d9b68 100644 --- a/aten/src/ATen/native/cuda/TensorCompare.cu +++ b/aten/src/ATen/native/cuda/TensorCompare.cu @@ -33,7 +33,7 @@ Tensor _s_where_cuda( const Tensor& self, const Tensor& other) { Tensor ret = at::empty(self.sizes(), self.options()); - AT_DISPATCH_ALL_TYPES_AND_HALF(ret.type(), "where", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ret.type(), "where", [&] { where_cuda(ret, condition, self, other); }); return ret; diff --git a/aten/src/ATen/native/cuda/TensorFactories.cu b/aten/src/ATen/native/cuda/TensorFactories.cu index 58b90bc..4a0a7f1 100644 --- a/aten/src/ATen/native/cuda/TensorFactories.cu +++ b/aten/src/ATen/native/cuda/TensorFactories.cu @@ -322,7 +322,7 @@ Tensor tril_indices_cuda( cuda::getApplyGrid(tril_size, dim_grid, tensor.get_device()), "unable to get dim grid"); - 
-  AT_DISPATCH_ALL_TYPES_AND_HALF(tensor.type(), "tril_indices_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, tensor.type(), "tril_indices_cuda", [&] {
     tril_indices_kernel<<<
         dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
         tensor.data(),
@@ -398,7 +398,7 @@ Tensor triu_indices_cuda(
     cuda::getApplyGrid(triu_size, dim_grid, tensor.get_device()),
     "unable to get dim grid");
 
-  AT_DISPATCH_ALL_TYPES_AND_HALF(tensor.type(), "triu_indices_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, tensor.type(), "triu_indices_cuda", [&] {
     triu_indices_kernel<<<
         dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
         tensor.data(),
diff --git a/aten/src/ATen/native/cuda/TensorTransformations.cu b/aten/src/ATen/native/cuda/TensorTransformations.cu
index 4eb25de..5fccc10 100644
--- a/aten/src/ATen/native/cuda/TensorTransformations.cu
+++ b/aten/src/ATen/native/cuda/TensorTransformations.cu
@@ -87,7 +87,7 @@ Tensor flip_cuda(const Tensor& self, IntArrayRef dims) {
 
   // use kernel_pointwise_flip_apply2 only when to-flip dim is the 1st or last dim, where collapseDims can reduce the amount of work
   if (flip_dims_size == 1 && in_tensor.is_contiguous() && (flip_dims[0] == 0 || flip_dims[0] == total_dims - 1)) {
-    AT_DISPATCH_ALL_TYPES_AND_HALF(in_tensor.type(), "flip_cuda", [&] {
+    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, in_tensor.type(), "flip_cuda", [&] {
       auto in_tensor_info = cuda::detail::getTensorInfo(in_tensor);
       auto out_tensor_info = cuda::detail::getTensorInfo(out_tensor);
       int flip_dim = in_tensor_info.collapseDims(flip_dims[0]);
@@ -119,7 +119,7 @@ Tensor flip_cuda(const Tensor& self, IntArrayRef dims) {
     }
   }
-  AT_DISPATCH_ALL_TYPES_AND_HALF(in_tensor.type(), "flip_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, in_tensor.type(), "flip_cuda", [&] {
     flip_cuda_kernel<<>>(
       in_tensor.data(), out_tensor.data(), N, flip_dims_t.toType(CUDA(kLong)).data(),
       flip_dims_size, strides_t.toType(CUDA(kLong)).data(), stride_contiguous.toType(CUDA(kLong)).data(),
       shape_t.toType(CUDA(kLong)).data(), total_dims);
@@ -177,7 +177,7 @@ Tensor roll_cuda(const Tensor& self, IntArrayRef shifts, IntArrayRef dims) {
 
   auto total_dims = in_tensor.dim();
 
-  AT_DISPATCH_ALL_TYPES_AND_HALF(in_tensor.type(), "roll_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, in_tensor.type(), "roll_cuda", [&] {
     roll_cuda_kernel<<>>(
       in_tensor.data(), out_tensor.data(), N,
       dim, start,
diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu
index 23e4cfb..fc144d6 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu
+++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu
@@ -94,8 +94,8 @@ SparseTensor coalesce_sparse_cuda(const SparseTensor& self) {
     int64_t stride = at::prod_intlist(values.sizes().slice(1));
     dim3 grid(THCCeilDiv(newNnz, (int64_t) 4), THCCeilDiv(stride, (int64_t) 128));
     dim3 block(32, 4);
-    AT_DISPATCH_ALL_TYPES_AND_HALF(
-        values.type(), "coalesce_sparse_cuda", [&] {
+    AT_DISPATCH_ALL_TYPES_AND(
+        at::ScalarType::Half,values.type(), "coalesce_sparse_cuda", [&] {
          using cuda_accscalar_t = acc_type;
          apply::coalesceValuesKernel<<>>(
            uniqueOffsets.data(),
diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu
index ba0a8de..cac6835 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu
+++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu
@@ -295,8 +295,8 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, SparseTensorR
   if (sparse.dense_dim() == 0) {
     AT_CHECK(cuda::getApplyGrid(nnz, grid, curDevice), "add: Argument #0: tensor too large or too many dimensions");
 
-    AT_DISPATCH_ALL_TYPES_AND_HALF(
-        values.type(), "add_out_dense_sparse_cuda", [&] {
+    AT_DISPATCH_ALL_TYPES_AND(
+        at::ScalarType::Half, values.type(), "add_out_dense_sparse_cuda", [&] {
          apply::sparseElementwiseKernelScalar, uint64_t, scalar_t>
            <<>>(
              TensorCAddOp(value.to()),
@@ -309,8 +309,8 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, SparseTensorR
     // sparseElementwiseKernel needs values to be contiguous too
     values = values.contiguous();
 
-    AT_DISPATCH_ALL_TYPES_AND_HALF(
-        values.type(), "add_out_dense_sparse_cuda", [&] {
+    AT_DISPATCH_ALL_TYPES_AND(
+        at::ScalarType::Half, values.type(), "add_out_dense_sparse_cuda", [&] {
          apply::sparseElementwiseKernel, uint64_t, scalar_t>
            <<>>(
              TensorCAddOp(value.to()),
@@ -323,8 +323,8 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, SparseTensorR
 
   // FIXME: at some point we can wrap the scale into indexAdd
   // NB: Purposely not inplace!
-  AT_DISPATCH_ALL_TYPES_AND_HALF(
-      values.type(), "add_out_dense_sparse_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(
+      at::ScalarType::Half, values.type(), "add_out_dense_sparse_cuda", [&] {
        if (value.to() != static_cast(1)) {
          values = values.mul(value);
        }
@@ -378,8 +378,8 @@ SparseTensor& add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, const
   LongTensor s_indices_ = src._indices();
   Tensor s_values_ = src._values();
 
-  AT_DISPATCH_ALL_TYPES_AND_HALF(
-      s_values_.type(), "add_out_sparse_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(
+      at::ScalarType::Half, s_values_.type(), "add_out_sparse_cuda", [&] {
        if (value.to() != static_cast(1)) {
          s_values_ = s_values_.mul(value);
        }
@@ -448,8 +448,8 @@ SparseTensor& mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, cons
   AT_CHECK(cuda::getApplyGrid(valueSize, grid, curDevice), "mul: Argument #0: tensor too large or too many dimensions");
 
   LongTensor resultNnz = at::empty({1}, CUDA(kLong));
-  AT_DISPATCH_ALL_TYPES_AND_HALF(
-      t_values_.type(), "mul_out_sparse_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(
+      at::ScalarType::Half, t_values_.type(), "mul_out_sparse_cuda", [&] {
        apply::valueSparseIntersectionKernel, uint64_t, scalar_t>
          <<>>(
            TensorMulOp(),
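
Note on the pattern (not part of the commit): every call-site change above is mechanical. The scalar type that used to be baked into the macro name (the _AND_HALF suffix) becomes the first macro argument, and the new MyTemplate specialization in Dispatch.h maps that at::ScalarType value back to a C++ type inside the generated switch. The sketch below is a self-contained, compilable illustration of that idea only; the names ScalarType, ScalarTypeToCPPType, PRIVATE_CASE_TYPE and DISPATCH_ALL_TYPES_AND are simplified stand-ins for at::ScalarType, MyTemplate, AT_PRIVATE_CASE_TYPE and AT_DISPATCH_ALL_TYPES_AND, and the Bool specialization merely suggests how the bool tensor work this PR prepares for could plug in an extra type.

#include <cstdint>
#include <iostream>
#include <stdexcept>

// Simplified stand-in for at::ScalarType.
enum class ScalarType { Byte, Int, Float, Half, Bool };

// Same idea as MyTemplate in Dispatch.h: map an enum value to a C++ type.
template <ScalarType N>
struct ScalarTypeToCPPType;

template <>
struct ScalarTypeToCPPType<ScalarType::Half> {
  using type = uint16_t;  // stand-in for at::Half in this sketch
};

template <>
struct ScalarTypeToCPPType<ScalarType::Bool> {
  using type = bool;  // hypothetical extra type, as the bool-tensor work would need
};

// The lambda body passed as __VA_ARGS__ is pasted into the matching case,
// so "scalar_t" inside it names the C++ type selected at runtime.
#define PRIVATE_CASE_TYPE(enum_type, cpp_type, ...) \
  case enum_type: {                                 \
    using scalar_t = cpp_type;                      \
    return __VA_ARGS__();                           \
  }

// The extra scalar type is a macro argument instead of part of the macro name.
#define DISPATCH_ALL_TYPES_AND(SCALARTYPE, the_type, ...)                 \
  [&] {                                                                   \
    switch (the_type) {                                                   \
      PRIVATE_CASE_TYPE(ScalarType::Byte, uint8_t, __VA_ARGS__)           \
      PRIVATE_CASE_TYPE(ScalarType::Int, int32_t, __VA_ARGS__)            \
      PRIVATE_CASE_TYPE(ScalarType::Float, float, __VA_ARGS__)            \
      PRIVATE_CASE_TYPE(                                                  \
          SCALARTYPE, ScalarTypeToCPPType<SCALARTYPE>::type, __VA_ARGS__) \
      default:                                                            \
        throw std::runtime_error("not implemented");                      \
    }                                                                     \
  }()

int main() {
  ScalarType t = ScalarType::Half;
  // Dispatch on a runtime type tag; here it lands on the Half stand-in.
  DISPATCH_ALL_TYPES_AND(ScalarType::Half, t, [&] {
    std::cout << "dispatched to a type of size " << sizeof(scalar_t) << "\n";
  });
  return 0;
}

With this shape, letting a call site handle one more scalar type is a one-argument change (pass a different at::ScalarType as the first argument, plus one MyTemplate specialization), rather than introducing a new AT_DISPATCH_* macro for every combination of types.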