From: Roy Li
Date: Sat, 9 Mar 2019 00:39:04 +0000 (-0800)
Subject: Change Dispatch.h to use ScalarType over Type
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~911
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=3aeb78079bcd68282fe9117088e138b77318e288;p=platform%2Fupstream%2Fpytorch.git

Change Dispatch.h to use ScalarType over Type

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17527

Reviewed By: zou3519

Differential Revision: D14235395

fbshipit-source-id: 3f53e33f6794f1f14c2edf79014b8ef8397822c5
---

diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h
index 20d7f2d..51dcae6 100644
--- a/aten/src/ATen/Dispatch.h
+++ b/aten/src/ATen/Dispatch.h
@@ -12,31 +12,28 @@
 #define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   [&] { \
-    const at::Type& the_type = TYPE; \
-    switch (the_type.scalarType()) { \
+    switch (TYPE) { \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
       default: \
-        AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
     } \
   }()
 
 #define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
   [&] { \
-    const at::Type& the_type = TYPE; \
-    switch (the_type.scalarType()) { \
+    switch (TYPE) { \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
       default: \
-        AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
     } \
   }()
 
 #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
   [&] { \
-    const at::Type& the_type = TYPE; \
-    switch (the_type.scalarType()) { \
+    switch (TYPE) { \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
@@ -47,28 +44,26 @@
       AT_PRIVATE_CASE_TYPE( \
           at::ScalarType::ComplexHalf, std::complex, __VA_ARGS__) \
       default: \
-        AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
     } \
   }()
 
 #define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
   [&] { \
-    const at::Type& the_type = TYPE; \
-    switch (the_type.scalarType()) { \
+    switch (TYPE) { \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
       default: \
-        AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
     } \
   }()
 
 #define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
   [&] { \
-    const at::Type& the_type = TYPE; \
-    switch (the_type.scalarType()) { \
+    switch (TYPE) { \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
@@ -77,7 +72,7 @@
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
       default: \
-        AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
     } \
   }()
 
@@ -91,8 +86,7 @@ struct MyTemplate {
 
 #define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
   [&] { \
-    const at::Type& the_type = TYPE; \
-    switch (the_type.scalarType()) { \
+    switch (TYPE) { \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
@@ -102,14 +96,13 @@ struct MyTemplate {
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
       AT_PRIVATE_CASE_TYPE(SCALARTYPE, MyTemplate::type, __VA_ARGS__) \
       default: \
-        AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
     } \
   }()
 
 #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \
   [&] { \
-    const at::Type& the_type = TYPE; \
-    switch (the_type.scalarType()) { \
+    switch (TYPE) { \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
@@ -123,6 +116,6 @@ struct MyTemplate {
       AT_PRIVATE_CASE_TYPE( \
           at::ScalarType::ComplexDouble, std::complex, __VA_ARGS__) \
       default: \
-        AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
     } \
   }()
diff --git a/aten/src/ATen/detail/ScalarTypeConversions.h b/aten/src/ATen/detail/ScalarTypeConversions.h
index 76fb0dc..ef04271 100644
--- a/aten/src/ATen/detail/ScalarTypeConversions.h
+++ b/aten/src/ATen/detail/ScalarTypeConversions.h
@@ -9,14 +9,14 @@ namespace at { namespace detail {
 
 template inline T load(const void* data, ScalarType src_type) {
-  return AT_DISPATCH_ALL_TYPES(CPU(src_type), "load", [&]() {
+  return AT_DISPATCH_ALL_TYPES(src_type, "load", [&]() {
     return at::convert(*(scalar_t*)data);
   });
 }
 
 template inline void store(T value, void* dst, ScalarType dst_type) {
-  AT_DISPATCH_ALL_TYPES(CPU(dst_type), "store", [&]() {
+  AT_DISPATCH_ALL_TYPES(dst_type, "store", [&]() {
     *(scalar_t*)dst = at::convert(value);
   });
 }
diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp
index acbad68..5426c86 100644
--- a/aten/src/ATen/native/Activation.cpp
+++ b/aten/src/ATen/native/Activation.cpp
@@ -150,7 +150,7 @@ Tensor prelu_cpu(const Tensor& self, const Tensor& weight_) {
 
   // case1: shared weight for all channels
   if (weight_num == 1) {
-    AT_DISPATCH_FLOATING_TYPES(input.type(), "prelu_cpu", [&] {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "prelu_cpu", [&] {
       prelu_cpu_kernel_share_weights(result, input, weight);
     });
   }
@@ -171,7 +171,7 @@ Tensor prelu_cpu(const Tensor& self, const Tensor& weight_) {
       "Mismatch of parameter numbers and input channel size.
Found parameter numbers = ", weight_num, " and channel size = ", channel_size, "."); - AT_DISPATCH_FLOATING_TYPES(input.type(), "prelu_cpu", [&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "prelu_cpu", [&] { prelu_cpu_kernel_multi_weights( result, input, @@ -277,7 +277,7 @@ std::tuple prelu_backward_cpu(const Tensor& grad_out_, const Ten // case1: shared parameter for all channels if (weight_num == 1) { - AT_DISPATCH_FLOATING_TYPES(input.type(), "prelu_backward_cpu", [&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "prelu_backward_cpu", [&] { prelu_cpu_backward_kernel_share_weights(input, weight, grad_out, input_grad, weight_grad); }); } @@ -298,7 +298,7 @@ std::tuple prelu_backward_cpu(const Tensor& grad_out_, const Ten "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num, " and channel size = ", channel_size, "."); - AT_DISPATCH_FLOATING_TYPES(input.type(), "prelu_backward_cpu", [&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "prelu_backward_cpu", [&] { prelu_cpu_backward_kernel_multi_weights( input, weight, @@ -326,7 +326,7 @@ std::tuple prelu_backward_cpu(const Tensor& grad_out_, const Ten // ----------------------------------- Tensor hardshrink_cpu(const Tensor & self, Scalar lambd) { auto out_tensor = at::empty_like(self); - AT_DISPATCH_FLOATING_TYPES(self.type(), "hardshrink_cpu", [&] { + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "hardshrink_cpu", [&] { auto lambd_val = lambd.to(); at::CPU_tensor_apply2( self, @@ -342,7 +342,7 @@ Tensor hardshrink_cpu(const Tensor & self, Scalar lambd) { Tensor hardshrink_backward_cpu(const Tensor & grad, const Tensor & self, Scalar lambd) { auto out_tensor = at::empty_like(self); - AT_DISPATCH_FLOATING_TYPES(self.type(), "hardshrink_backward_cpu", [&] { + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "hardshrink_backward_cpu", [&] { auto lambd_val = lambd.to(); at::CPU_tensor_apply3( self, diff --git a/aten/src/ATen/native/AdaptiveAveragePooling.cpp b/aten/src/ATen/native/AdaptiveAveragePooling.cpp index a31211b..94314f5 100644 --- a/aten/src/ATen/native/AdaptiveAveragePooling.cpp +++ b/aten/src/ATen/native/AdaptiveAveragePooling.cpp @@ -102,7 +102,7 @@ namespace { { output.resize_({sizeD, osizeH, osizeW}); - AT_DISPATCH_FLOATING_TYPES(input.type(), "adaptive_avg_pool2d", [&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "adaptive_avg_pool2d_cpu", [&] { auto input_data = input.data(); auto output_data = output.data(); adaptive_avg_pool2d_out_frame(input_data, output_data, @@ -121,7 +121,7 @@ namespace { #pragma omp parallel for private(b) for (b = 0; b < input.size(0); b++) { - AT_DISPATCH_FLOATING_TYPES(input.type(), "adaptive_avg_pool2d", [&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "adaptive_avg_pool2d_cpu", [&] { auto input_data = input.data(); auto output_data = output.data(); adaptive_avg_pool2d_out_frame(input_data+b*input.stride(0), output_data+b*sizeD*osizeH*osizeW, @@ -203,7 +203,7 @@ namespace { if (input.ndimension() == 3) { AT_DISPATCH_FLOATING_TYPES( - input.type(), "adaptive_avg_pool2d_backward", [&] { + input.scalar_type(), "adaptive_avg_pool2d_backward_cpu", [&] { /* get raw pointers */ scalar_t *gradInput_data = gradInput.data(); scalar_t *gradOutput_data = gradOutput.data(); @@ -223,7 +223,7 @@ namespace { for (b = 0; b < input.size(0); b++) { AT_DISPATCH_FLOATING_TYPES( - input.type(), "adaptive_avg_pool2d_backward", [&] { + input.scalar_type(), "adaptive_avg_pool2d_backward_cpu", [&] { /* get raw pointers */ scalar_t *gradInput_data = 
gradInput.data(); scalar_t *gradOutput_data = gradOutput.data(); @@ -262,7 +262,7 @@ namespace { return output; } - Tensor adaptive_avg_pool2d( + Tensor adaptive_avg_pool2d( at::Tensor const& input, IntArrayRef output_size){ if (output_size[0] == 1 && output_size[1] == 1) { diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index b373181..767368a 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -149,13 +149,13 @@ std::tuple _gesv_helper_cpu(const Tensor& self, const Tensor& A) auto self_working_copy = cloneBatchedColumnMajor(self); auto A_working_copy = cloneBatchedColumnMajor(A); std::vector infos(batchCount(self), 0); - AT_DISPATCH_FLOATING_TYPES(self.type(), "gesv", [&]{ + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "gesv_cpu", [&]{ apply_gesv(self_working_copy, A_working_copy, infos); }); if (self.dim() > 2) { - batchCheckErrors(infos, "gesv"); + batchCheckErrors(infos, "gesv_cpu"); } else { - singleCheckErrors(infos[0], "gesv"); + singleCheckErrors(infos[0], "gesv_cpu"); } return std::tuple(self_working_copy, A_working_copy); } @@ -172,7 +172,7 @@ std::tuple gesv(const Tensor& self, const Tensor& A) { } std::tuple gesv_out(Tensor& solution, Tensor& lu, const Tensor& self, const Tensor& A) { - AT_CHECK(self.dim() == 2 && A.dim() == 2, + AT_CHECK(self.dim() == 2 && A.dim() == 2, "torch.gesv() with the `out` keyword does not support batching. " "b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2."); Tensor solution_tmp, lu_tmp; @@ -229,10 +229,10 @@ static void apply_inverse(Tensor& self, std::vector& infos) { Tensor _inverse_helper_cpu(const Tensor& self) { std::vector infos(batchCount(self), 0); auto self_working_copy = cloneBatchedColumnMajor(self); - AT_DISPATCH_FLOATING_TYPES(self.type(), "inverse", [&]{ + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "inverse_cpu", [&]{ apply_inverse(self_working_copy, infos); }); - batchCheckErrors(infos, "inverse"); + batchCheckErrors(infos, "inverse_cpu"); return self_working_copy; } @@ -294,13 +294,13 @@ Tensor _cholesky_solve_helper_cpu(const Tensor& self, const Tensor& A, bool uppe auto self_working_copy = cloneBatchedColumnMajor(self); auto A_working_copy = cloneBatchedColumnMajor(A); std::vector infos(batchCount(self), 0); - AT_DISPATCH_FLOATING_TYPES(self.type(), "cholesky_solve", [&]{ + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "cholesky_solve_cpu", [&]{ apply_cholesky_solve(self_working_copy, A_working_copy, upper, infos); }); if (self.dim() > 2) { - batchCheckErrors(infos, "cholesky_solve"); + batchCheckErrors(infos, "cholesky_solve_cpu"); } else { - singleCheckErrors(infos[0], "cholesky_solve"); + singleCheckErrors(infos[0], "cholesky_solve_cpu"); } return self_working_copy; } @@ -358,13 +358,13 @@ static void apply_cholesky(Tensor& self, bool upper, std::vector& infos Tensor _cholesky_helper_cpu(const Tensor& self, bool upper) { std::vector infos(batchCount(self), 0); auto self_working_copy = cloneBatchedColumnMajor(self); - AT_DISPATCH_FLOATING_TYPES(self.type(), "cholesky", [&]{ + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "cholesky_cpu", [&]{ apply_cholesky(self_working_copy, upper, infos); }); if (self.dim() > 2) { - batchCheckErrors(infos, "cholesky"); + batchCheckErrors(infos, "cholesky_cpu"); } else { - singleCheckErrors(infos[0], "cholesky"); + singleCheckErrors(infos[0], "cholesky_cpu"); } return self_working_copy; } @@ -474,7 +474,7 @@ Tensor& tril_cpu_(Tensor &self, int64_t k) { 
bool inplace = checkTrilTriuBatchContiguous(self); Tensor self_c = inplace ? self : self.contiguous(); Tensor result = inplace ? self : at::empty_like(self); - AT_DISPATCH_ALL_TYPES(self.type(), "tril", [&]{ + AT_DISPATCH_ALL_TYPES(self.scalar_type(), "tril", [&]{ apply_triu_tril(result, self_c, inplace, k); }); if (!inplace) self.copy_(result); @@ -489,7 +489,7 @@ Tensor& tril_cpu_out(Tensor &result, const Tensor& self, int64_t k) { return result; } Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous(); - AT_DISPATCH_ALL_TYPES(self.type(), "tril", [&]{ + AT_DISPATCH_ALL_TYPES(self.scalar_type(), "tril", [&]{ apply_triu_tril(result, self_c, false, k); }); return result; @@ -508,7 +508,7 @@ Tensor& triu_cpu_(Tensor &self, int64_t k) { bool inplace = checkTrilTriuBatchContiguous(self); Tensor self_c = inplace ? self : self.contiguous(); Tensor result = inplace ? self : at::empty_like(self); - AT_DISPATCH_ALL_TYPES(self.type(), "triu", [&]{ + AT_DISPATCH_ALL_TYPES(self.scalar_type(), "triu", [&]{ apply_triu_tril(result, self_c, inplace, k); }); if (!inplace) self.copy_(result); @@ -523,7 +523,7 @@ Tensor& triu_cpu_out(Tensor &result, const Tensor& self, int64_t k) { return result; } Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous(); - AT_DISPATCH_ALL_TYPES(self.type(), "triu", [&]{ + AT_DISPATCH_ALL_TYPES(self.scalar_type(), "triu", [&]{ apply_triu_tril(result, self_c, false, k); }); return result; diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index a8c7bab..0ff0b0f 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -20,7 +20,7 @@ void _copy__cpu(at::Tensor& self, const at::Tensor& src) { template void _copy__cpu(at::Tensor& self, const at::Tensor& src) { AT_CHECK(self.numel() == src.numel(), "sizes do not match"); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "_copy__cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_copy__cpu", [&]() { _copy__cpu(self, src); }); } @@ -43,7 +43,7 @@ Tensor& _s_copy__cpu(Tensor& self, const Tensor& src, bool non_blocking) { return self; } AT_DISPATCH_ALL_TYPES_AND( - at::ScalarType::Half, self.type(), "_copy__cpu", [&]() { ::_copy__cpu(self, src); }); + at::ScalarType::Half, self.scalar_type(), "_copy__cpu", [&]() { ::_copy__cpu(self, src); }); return self; } @@ -59,7 +59,7 @@ void _copy_same_type_transpose_(Tensor& self, const Tensor& src) { Tensor buf = empty({BLOCK_SZ, BLOCK_SZ}, self.options()); AT_DISPATCH_ALL_TYPES_AND( - at::ScalarType::Half, self.type(), "_copy_same_type_transpose_", [&]() { + at::ScalarType::Half, self.scalar_type(), "_copy_same_type_transpose_", [&]() { scalar_t* sp = src.data(); scalar_t* rp = self.data(); scalar_t* bp = buf.data(); @@ -115,7 +115,7 @@ void _copy_same_type__cpu(Tensor& self, const Tensor& src) { #ifdef _OPENMP if (!in_parallel_region()) { AT_DISPATCH_ALL_TYPES_AND( - at::ScalarType::Half, self.type(), "_copy_same_type_", [&]() { + at::ScalarType::Half, self.scalar_type(), "_copy_same_type_", [&]() { at::CPU_tensor_parallel_apply2( self, src, [](scalar_t& self_val, const scalar_t& src_val) { self_val = src_val; @@ -134,7 +134,7 @@ void _copy_same_type__cpu(Tensor& self, const Tensor& src) { if (serial_path) { AT_DISPATCH_ALL_TYPES_AND( - at::ScalarType::Half, self.type(), "_copy_same_type_", [&]() { + at::ScalarType::Half, self.scalar_type(), "_copy_same_type_", [&]() { at::CPU_tensor_apply2( self, src, [](scalar_t& self_val, const scalar_t& src_val) { 
self_val = src_val; diff --git a/aten/src/ATen/native/Distributions.cpp b/aten/src/ATen/native/Distributions.cpp index 30ffb76..6ef1848 100644 --- a/aten/src/ATen/native/Distributions.cpp +++ b/aten/src/ATen/native/Distributions.cpp @@ -126,7 +126,7 @@ Tensor& bernoulli_out(Tensor& result, const Tensor& self, Generator* gen) { } Tensor& bernoulli_tensor_cpu_(Tensor& self, const Tensor& p_, Generator* gen) { - AT_DISPATCH_ALL_TYPES(self.type(), "bernoulli_tensor_cpu_self_", [&] { + AT_DISPATCH_ALL_TYPES(self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] { THGenerator* generator = get_generator(gen); std::lock_guard lock(generator->mutex); using self_t = scalar_t; @@ -137,7 +137,7 @@ Tensor& bernoulli_tensor_cpu_(Tensor& self, const Tensor& p_, Generator* gen) { ret_val = static_cast(THRandom_bernoulli(generator, p_val)); }); } else { - AT_DISPATCH_FLOATING_TYPES(p_.type(), "bernoulli_tensor_cpu_p_", [&] { + AT_DISPATCH_FLOATING_TYPES(p_.scalar_type(), "bernoulli_tensor_cpu_p_", [&] { auto p = std::get<0>(expand_inplace(self, p_.to(kCPU))); using p_t = scalar_t; CPU_tensor_apply2( @@ -160,7 +160,7 @@ Tensor& bernoulli_scalar_cpu_(Tensor& self, double p, Generator* gen) { return self; } #endif - AT_DISPATCH_ALL_TYPES(self.type(), "bernoulli_scalar_cpu_", [&] { + AT_DISPATCH_ALL_TYPES(self.scalar_type(), "bernoulli_scalar_cpu_", [&] { THGenerator* generator = get_generator(gen); std::lock_guard lock(generator->mutex); CPU_tensor_apply1( @@ -174,7 +174,7 @@ Tensor& bernoulli_scalar_cpu_(Tensor& self, double p, Generator* gen) { Tensor _standard_gamma_grad_cpu(const Tensor& self, const Tensor& output) { Tensor ret = at::empty(self.sizes(), self.options()); - AT_DISPATCH_FLOATING_TYPES(self.type(), "_standard_gamma_grad", [&] { + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "_standard_gamma_grad_cpu", [&] { CPU_tensor_apply3(ret, self, output, [](scalar_t& ret_val, const scalar_t& self_val, const scalar_t &output_val) { ret_val = standard_gamma_grad_one(self_val, output_val); @@ -190,7 +190,7 @@ Tensor _standard_gamma_grad_cpu(const Tensor& self, const Tensor& output) { Tensor _s_poisson_cpu(const Tensor& lambda, Generator *gen) { Tensor ret = at::zeros(lambda.sizes(), lambda.options()); - AT_DISPATCH_FLOATING_TYPES(ret.type(), "poisson", [&] { + AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "poisson_cpu", [&] { THGenerator* generator = get_generator(gen); std::lock_guard lock(generator->mutex); CPU_tensor_apply2(ret, lambda, @@ -204,7 +204,7 @@ Tensor _s_poisson_cpu(const Tensor& lambda, Generator *gen) { Tensor _s_gamma_cpu(const Tensor& alpha, Generator *gen) { Tensor ret = at::zeros(alpha.sizes(), alpha.options()); - AT_DISPATCH_FLOATING_TYPES(ret.type(), "gamma", [&] { + AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "gamma_cpu", [&] { THGenerator* generator = get_generator(gen); std::lock_guard lock(generator->mutex); CPU_tensor_apply2(ret, alpha, diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp index 11e60e5..0ed1705 100644 --- a/aten/src/ATen/native/EmbeddingBag.cpp +++ b/aten/src/ATen/native/EmbeddingBag.cpp @@ -199,7 +199,7 @@ _embedding_bag_cpu(const Tensor &weight, const Tensor &indices, return std::tuple(ret, offset2bag, bag_size, bag_size); } else { // MODE_MAX return AT_DISPATCH_FLOATING_TYPES_AND_HALF( - weight.type(), "embedding_bag_cpu_max", [&]() { + weight.scalar_type(), "embedding_bag_cpu_max", [&]() { return embedding_bag_cpu_max(weight, indices, offset2bag, output, bag_size, offsets); } ); diff --git 
a/aten/src/ATen/native/FractionalMaxPool2d.cpp b/aten/src/ATen/native/FractionalMaxPool2d.cpp index 54e362c..1980306 100644 --- a/aten/src/ATen/native/FractionalMaxPool2d.cpp +++ b/aten/src/ATen/native/FractionalMaxPool2d.cpp @@ -178,7 +178,7 @@ void fractional_max_pool2d_out_cpu_template( indices.resize_({numBatch, numPlanes, outputH, outputW}); } - AT_DISPATCH_FLOATING_TYPES(input.type(), + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "fractional_max_pool2d_out_frame", [&] { auto input_data = input.data(); auto output_data = output.data(); @@ -295,7 +295,7 @@ Tensor& fractional_max_pool2d_backward_out_cpu_template( /* backprop */ AT_DISPATCH_FLOATING_TYPES( - input.type(), "fractional_max_pool2d_backward_out_frame", [&] { + input.scalar_type(), "fractional_max_pool2d_backward_out_frame", [&] { auto gradInput_data = gradInput.data(); auto gradOutput_data = gradOutput.data(); auto indices_data = indices.data(); diff --git a/aten/src/ATen/native/FractionalMaxPool3d.cpp b/aten/src/ATen/native/FractionalMaxPool3d.cpp index 72e6722..30dc2b4 100644 --- a/aten/src/ATen/native/FractionalMaxPool3d.cpp +++ b/aten/src/ATen/native/FractionalMaxPool3d.cpp @@ -201,7 +201,7 @@ void fractional_max_pool3d_out_cpu_template( indices.resize_({numBatch, numPlanes, outputT, outputH, outputW}); } AT_DISPATCH_FLOATING_TYPES( - input.type(), + input.scalar_type(), "fractional_max_pool3d_out_frame", [&] { fractional_max_pool3d_out_frame( @@ -330,7 +330,7 @@ void fractional_max_pool3d_backward_out_cpu_template( /* backprop */ AT_DISPATCH_FLOATING_TYPES( - input.type(), + input.scalar_type(), "fractional_max_pool3d_backward_out_frame", [&]{ fractional_max_pool3d_backward_out_frame( diff --git a/aten/src/ATen/native/GridSampler.cpp b/aten/src/ATen/native/GridSampler.cpp index e9bb623..65af517 100644 --- a/aten/src/ATen/native/GridSampler.cpp +++ b/aten/src/ATen/native/GridSampler.cpp @@ -534,7 +534,7 @@ DEFINE_DISPATCH(grid_sampler_2d_cpu_kernel); // No shape checking needed here. See # NOTE [ grid_sampler Native Functions ]. 
Tensor grid_sampler_3d_cpu(const Tensor& input, const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode) { - return AT_DISPATCH_FLOATING_TYPES(input.type(), "grid_sampler3d_cpu", [&] { + return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler3d_cpu", [&] { return grid_sampler_3d_cpu_impl( input, grid, static_cast(interpolation_mode), static_cast(padding_mode)); @@ -554,7 +554,7 @@ DEFINE_DISPATCH(grid_sampler_2d_backward_cpu_kernel); std::tuple grid_sampler_3d_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode) { - return AT_DISPATCH_FLOATING_TYPES(input.type(), "grid_sampler_3d_backward_cpu", [&] { + return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_3d_backward_cpu", [&] { return grid_sampler_3d_backward_cpu_impl( grad_output, input, grid, static_cast(interpolation_mode), diff --git a/aten/src/ATen/native/Lerp.cpp b/aten/src/ATen/native/Lerp.cpp index a541cfe..96f9534 100644 --- a/aten/src/ATen/native/Lerp.cpp +++ b/aten/src/ATen/native/Lerp.cpp @@ -38,9 +38,9 @@ Tensor& lerp_cpu_tensor_out(Tensor& result, const Tensor& self, Tensor b_self, b_end, b_weight; AT_CHECK(weight.dim() <= std::max(self.dim(), end.dim()), "weight should be of dimension max(self.dim(), end.dim()) or lesser"); - std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp_out"); + std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp_out_cpu"); result.resize_as_(b_self); - AT_DISPATCH_FLOATING_TYPES(self.type(), "lerp_out", [&]{ + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp_out_cpu", [&]{ lerp_cpu(result, b_self, b_end, b_weight); }); return result; @@ -49,9 +49,9 @@ Tensor& lerp_cpu_tensor_out(Tensor& result, const Tensor& self, Tensor& lerp_cpu_scalar_out(Tensor& result, const Tensor& self, const Tensor& end, Scalar weight) { Tensor b_self, b_end; - std::tie(b_self, b_end) = expand_outplace(self, end, "lerp_out"); + std::tie(b_self, b_end) = expand_outplace(self, end, "lerp_out_cpu"); result.resize_as_(b_self); - AT_DISPATCH_FLOATING_TYPES(self.type(), "lerp_out", [&]{ + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp_out_cpu", [&]{ lerp_cpu(result, b_self, b_end, weight.to()); }); return result; @@ -59,13 +59,13 @@ Tensor& lerp_cpu_scalar_out(Tensor& result, const Tensor& self, Tensor& lerp_cpu_tensor_(Tensor& self, const Tensor& end, const Tensor& weight) { Tensor b_self, b_end, b_weight; - std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp_"); + std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp__cpu"); AT_CHECK(b_self.sizes() == self.sizes(), "output with shape ", self.sizes(), " doesn't match the broadcast shape ", b_self.sizes()); AT_CHECK(weight.dim() <= std::max(self.dim(), end.dim()), "weight should be of dimension max(self.dim(), end.dim()) or lesser"); - AT_DISPATCH_FLOATING_TYPES(self.type(), "lerp_", [&]{ + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp__cpu", [&]{ lerp_cpu(self, b_self, b_end, b_weight); }); return self; @@ -73,11 +73,11 @@ Tensor& lerp_cpu_tensor_(Tensor& self, const Tensor& end, const Tensor& weight) Tensor& lerp_cpu_scalar_(Tensor& self, const Tensor& end, Scalar weight) { Tensor b_self, b_end; - std::tie(b_self, b_end) = expand_outplace(self, end, "lerp_"); + std::tie(b_self, b_end) = expand_outplace(self, end, "lerp__cpu"); AT_CHECK(b_self.sizes() == self.sizes(), "output with shape ", self.sizes(), " doesn't match the broadcast shape ", 
b_self.sizes()); - AT_DISPATCH_FLOATING_TYPES(self.type(), "lerp_", [&]{ + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp__cpu", [&]{ lerp_cpu(self, b_self, b_end, weight.to()); }); return self; @@ -87,9 +87,9 @@ Tensor lerp_cpu_tensor(const Tensor& self, const Tensor& end, const Tensor& weig Tensor b_self, b_end, b_weight; AT_CHECK(weight.dim() <= std::max(self.dim(), end.dim()), "weight should be of dimension max(self.dim(), end.dim()) or lesser"); - std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp"); + std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp_cpu"); Tensor result = at::empty_like(b_self); - AT_DISPATCH_FLOATING_TYPES(result.type(), "lerp", [&]{ + AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "lerp_cpu", [&]{ lerp_cpu(result, b_self, b_end, b_weight); }); return result; @@ -97,9 +97,9 @@ Tensor lerp_cpu_tensor(const Tensor& self, const Tensor& end, const Tensor& weig Tensor lerp_cpu_scalar(const Tensor& self, const Tensor& end, Scalar weight) { Tensor b_self, b_end; - std::tie(b_self, b_end) = expand_outplace(self, end, "lerp"); + std::tie(b_self, b_end) = expand_outplace(self, end, "lerp_cpu"); Tensor result = at::empty_like(b_self); - AT_DISPATCH_FLOATING_TYPES(result.type(), "lerp", [&]{ + AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "lerp_cpu", [&]{ lerp_cpu(result, b_self, b_end, weight.to()); }); return result; diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 42cad7c..c39fcbb 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -299,11 +299,11 @@ static inline Tensor& bmm_out_or_baddbmm_(Tensor& self_or_result, const Tensor& if (contraction_size * res_rows * res_cols < 400) { if (is_bmm_out) { - AT_DISPATCH_ALL_TYPES(batch1.type(), "bmm", [&] { + AT_DISPATCH_ALL_TYPES(batch1.scalar_type(), "bmm", [&] { baddbmm_cpu_kernel(self_or_result, batch1, batch2, beta, alpha); }); } else { - AT_DISPATCH_ALL_TYPES(batch1.type(), "baddbmm", [&] { + AT_DISPATCH_ALL_TYPES(batch1.scalar_type(), "baddbmm", [&] { baddbmm_cpu_kernel(self_or_result, batch1, batch2, beta, alpha); }); } diff --git a/aten/src/ATen/native/Loss.cpp b/aten/src/ATen/native/Loss.cpp index 6722d85..7def0da 100644 --- a/aten/src/ATen/native/Loss.cpp +++ b/aten/src/ATen/native/Loss.cpp @@ -70,7 +70,7 @@ Tensor kl_div(const Tensor& input, const Tensor& target, int64_t reduction) { Tensor kl_div_backward_cpu(const Tensor& grad, const Tensor& input, const Tensor& target, int64_t reduction) { auto grad_input = at::zeros_like(input); auto grad_expand = grad.expand_as(input); - AT_DISPATCH_FLOATING_TYPES(input.type(), "kl_div_backward", [&]() { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "kl_div_backward_cpu", [&]() { at::CPU_tensor_apply3( grad_input, target, diff --git a/aten/src/ATen/native/LossCTC.cpp b/aten/src/ATen/native/LossCTC.cpp index 5b7f376..f6d8906 100644 --- a/aten/src/ATen/native/LossCTC.cpp +++ b/aten/src/ATen/native/LossCTC.cpp @@ -307,7 +307,7 @@ Tensor ctc_loss_backward_cpu_template(const Tensor& grad_out, const Tensor& log_ std::tuple ctc_loss_cpu(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK, bool zero_infinity) { (void)zero_infinity; // only used for backwards - return AT_DISPATCH_FLOATING_TYPES(log_probs.type(), "ctc_loss", [&] { + return AT_DISPATCH_FLOATING_TYPES(log_probs.scalar_type(), "ctc_loss_cpu", [&] { if (targets.scalar_type() == kLong) { return 
ctc_loss_cpu_template(log_probs, targets, input_lengths, target_lengths, BLANK); } else { @@ -318,7 +318,7 @@ std::tuple ctc_loss_cpu(const Tensor& log_probs, const Tensor& t Tensor ctc_loss_backward_cpu(const Tensor& grad, const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, const Tensor& neg_log_likelihood, const Tensor& log_alpha, int64_t BLANK, bool zero_infinity) { - return AT_DISPATCH_FLOATING_TYPES(log_probs.type(), "ctc_loss_backward", [&] { + return AT_DISPATCH_FLOATING_TYPES(log_probs.scalar_type(), "ctc_loss_backward_cpu", [&] { if (targets.scalar_type() == kLong) { return ctc_loss_backward_cpu_template(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, BLANK, zero_infinity); } else { diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index f62bd4b..e4be451 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -486,7 +486,7 @@ Tensor group_norm(const Tensor& input, int64_t num_groups, std::tuple batch_norm_update_stats_cpu( const Tensor& self, const Tensor& running_mean, const Tensor& running_var, double momentum) { - return AT_DISPATCH_FLOATING_TYPES(self.type(), "batch_norm_update_stats", [&] { + return AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "batch_norm_update_stats_cpu", [&] { return batch_norm_cpu_update_stats_template(self, running_mean, running_var, momentum, 0); }); } @@ -496,7 +496,7 @@ std::tuple batch_norm_cpu(const Tensor& self, const Tens bool train, double momentum, double eps) { checkBackend("batch_norm_cpu", {self, weight, bias, running_mean, running_var}, Backend::CPU); - return AT_DISPATCH_FLOATING_TYPES(self.type(), "batch_norm", [&] { + return AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "batch_norm", [&] { if (!train) { return batch_norm_cpu_transform_input_template(self, weight, bias, {}, {}, running_mean, running_var, train, eps); } else { @@ -509,7 +509,7 @@ std::tuple batch_norm_cpu(const Tensor& self, const Tens std::tuple batch_norm_backward_cpu(const Tensor& grad_out, const Tensor& self, const Tensor& weight, const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd, bool train, double eps, std::array grad_input_mask) { - return AT_DISPATCH_FLOATING_TYPES(self.type(), "batch_norm_backward", [&] { + return AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "batch_norm_backward_cpu", [&] { return batch_norm_backward_cpu_template(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, eps, grad_input_mask); }); } diff --git a/aten/src/ATen/native/RangeFactories.cpp b/aten/src/ATen/native/RangeFactories.cpp index 2f71149..02abdab 100644 --- a/aten/src/ATen/native/RangeFactories.cpp +++ b/aten/src/ATen/native/RangeFactories.cpp @@ -21,7 +21,7 @@ Tensor& linspace_cpu_out(Tensor& result, Scalar start, Scalar end, int64_t steps } else if (steps == 1) { r.fill_(start); } else { - AT_DISPATCH_FLOATING_TYPES(r.type(), "linspace", [&]() { + AT_DISPATCH_FLOATING_TYPES(r.scalar_type(), "linspace_cpu", [&]() { scalar_t scalar_start = start.to(); scalar_t scalar_end = end.to(); scalar_t *data_ptr = r.data(); @@ -54,7 +54,7 @@ Tensor& logspace_cpu_out(Tensor& result, Scalar start, Scalar end, int64_t steps } else if (steps == 1) { r.fill_(std::pow(10.0, start.to())); } else { - AT_DISPATCH_FLOATING_TYPES(r.type(), "logspace", [&]() { + AT_DISPATCH_FLOATING_TYPES(r.scalar_type(), "logspace_cpu", [&]() { scalar_t base10 
= 10; scalar_t scalar_start = start.to(); scalar_t scalar_end = end.to(); @@ -76,7 +76,7 @@ Tensor& logspace_cpu_out(Tensor& result, Scalar start, Scalar end, int64_t steps } Tensor& range_cpu_out(Tensor& result, Scalar start, Scalar end, Scalar step) { - AT_DISPATCH_ALL_TYPES(result.type(), "range", [&]() { + AT_DISPATCH_ALL_TYPES(result.scalar_type(), "range_cpu", [&]() { using accscalar_t = at::acc_type; auto xstart = start.to(); auto xend = end.to(); @@ -110,7 +110,7 @@ Tensor& range_cpu_out(Tensor& result, Scalar start, Scalar end, Scalar step) { } Tensor& arange_cpu_out(Tensor& result, Scalar start, Scalar end, Scalar step) { - AT_DISPATCH_ALL_TYPES(result.type(), "arange", [&]() { + AT_DISPATCH_ALL_TYPES(result.scalar_type(), "arange_cpu", [&]() { using accscalar_t = at::acc_type; auto xstart = start.to(); auto xend = end.to(); diff --git a/aten/src/ATen/native/ReflectionPad.cpp b/aten/src/ATen/native/ReflectionPad.cpp index 0badbe3..b8ee136 100644 --- a/aten/src/ATen/native/ReflectionPad.cpp +++ b/aten/src/ATen/native/ReflectionPad.cpp @@ -91,7 +91,7 @@ void reflection_pad1d_out_template( /* resize output */ if (input.ndimension() == 2) { output.resize_({nplane, output_w}); - AT_DISPATCH_FLOATING_TYPES(input.type(), "reflection_pad1d", [&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "reflection_pad1d", [&] { reflection_pad1d_out_frame( input.data(), output.data(), nplane, @@ -100,7 +100,7 @@ void reflection_pad1d_out_template( }); } else { output.resize_({nbatch, nplane, output_w}); - AT_DISPATCH_FLOATING_TYPES(input.type(), "reflection_pad1d", [&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "reflection_pad1d", [&] { reflection_pad1d_out_loop( input.data(), output.data(), nbatch, nplane, @@ -187,7 +187,7 @@ void reflection_pad1d_backward_out_template( /* backprop */ if (input.ndimension() == 2) { AT_DISPATCH_FLOATING_TYPES( - grad_input.type(), "reflection_pad1d_backward", [&] { + grad_input.scalar_type(), "reflection_pad1d_backward", [&] { reflection_pad1d_backward_out_frame( grad_input.data(), grad_output.data(), nplane, @@ -197,7 +197,7 @@ void reflection_pad1d_backward_out_template( ); } else { AT_DISPATCH_FLOATING_TYPES( - grad_input.type(), "reflection_pad1d_backward", [&] { + grad_input.scalar_type(), "reflection_pad1d_backward", [&] { reflection_pad1d_backward_out_loop( grad_input.data(), grad_output.data(), @@ -322,7 +322,7 @@ void reflection_pad2d_out_template( if (input.ndimension() == 3) { /* resize output */ output.resize_({nplane, output_h, output_w}); - AT_DISPATCH_FLOATING_TYPES(input.type(), "reflection_pad2d", [&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "reflection_pad2d", [&] { reflection_pad2d_out_frame( input.data(), output.data(), nplane, @@ -332,7 +332,7 @@ void reflection_pad2d_out_template( } else { /* resize output */ output.resize_({nbatch, nplane, output_h, output_w}); - AT_DISPATCH_FLOATING_TYPES(input.type(), "reflection_pad2d", [&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "reflection_pad2d", [&] { reflection_pad2d_out_loop( input.data(), output.data(), nbatch, nplane, @@ -448,7 +448,7 @@ void reflection_pad2d_backward_out_template( /* backprop */ if (input.ndimension() == 3) { AT_DISPATCH_FLOATING_TYPES( - grad_output.type(), "reflection_pad2d_backward", [&] { + grad_output.scalar_type(), "reflection_pad2d_backward", [&] { reflection_pad2d_backward_out_frame( grad_input.data(), grad_output.data(), nplane, @@ -458,7 +458,7 @@ void reflection_pad2d_backward_out_template( ); } else { AT_DISPATCH_FLOATING_TYPES( - 
grad_output.type(), "reflection_pad2d_backward", [&] { + grad_output.scalar_type(), "reflection_pad2d_backward", [&] { reflection_pad2d_backward_out_loop( grad_input.data(), grad_output.data(), nbatch, nplane, diff --git a/aten/src/ATen/native/ReplicationPadding.cpp b/aten/src/ATen/native/ReplicationPadding.cpp index 969dda5..0431cd0 100644 --- a/aten/src/ATen/native/ReplicationPadding.cpp +++ b/aten/src/ATen/native/ReplicationPadding.cpp @@ -97,7 +97,7 @@ void replication_pad1d_out_cpu_template( if (input.ndimension() == 2) { output.resize_({nslices, owidth}); - AT_DISPATCH_FLOATING_TYPES(input.type(), "replication_pad1d", [&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "replication_pad1d_cpu", [&] { auto input_data = input.data(); auto output_data = output.data(); replication_pad1d_out_frame( @@ -113,7 +113,7 @@ void replication_pad1d_out_cpu_template( else { output.resize_({nbatch, nslices, owidth}); - AT_DISPATCH_FLOATING_TYPES(input.type(), "replication_pad1d", [&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "replication_pad1d_cpu", [&] { auto input_data = input.data(); auto output_data = output.data(); replication_pad1d_out_batch( @@ -219,7 +219,7 @@ Tensor& replication_pad1d_backward_out_cpu_template( if (input.ndimension() == 2) { AT_DISPATCH_FLOATING_TYPES( - input.type(), "replication_pad1d_backward", [&] { + input.scalar_type(), "replication_pad1d_backward_cpu", [&] { scalar_t *gradInput_data = gradInput.data(); scalar_t *gradOutput_data = gradOutput.data(); @@ -236,7 +236,7 @@ Tensor& replication_pad1d_backward_out_cpu_template( else { AT_DISPATCH_FLOATING_TYPES( - input.type(), "replication_pad1d_backward", [&] { + input.scalar_type(), "replication_pad1d_backward_cpu", [&] { scalar_t *gradInput_data = gradInput.data(); scalar_t *gradOutput_data = gradOutput.data(); @@ -365,7 +365,7 @@ void replication_pad2d_out_cpu_template(Tensor& output, if (input.dim() == 3) { output.resize_({nslices, oheight, owidth}); - AT_DISPATCH_FLOATING_TYPES(input.type(), "replication_pad2d", [&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "replication_pad2d_cpu", [&] { auto input_data = input.data(); auto output_data = output.data(); replication_pad2d_out_frame (input_data, output_data, @@ -380,7 +380,7 @@ void replication_pad2d_out_cpu_template(Tensor& output, else { output.resize_({nbatch, nslices, oheight, owidth}); - AT_DISPATCH_FLOATING_TYPES(input.type(), "replication_pad2d", [&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "replication_pad2d_cpu", [&] { auto input_data = input.data(); auto output_data = output.data(); replication_pad2d_out_batch (input_data, output_data, @@ -511,7 +511,7 @@ Tensor& replication_pad2d_backward_out_cpu_template( if (input.dim() == 3) { AT_DISPATCH_FLOATING_TYPES( - input.type(), "replication_pad2d_backward", [&] { + input.scalar_type(), "replication_pad2d_backward_cpu", [&] { replication_pad2d_backward_out_frame( gradInput.data(), gradOutput.data(), @@ -526,7 +526,7 @@ Tensor& replication_pad2d_backward_out_cpu_template( else { AT_DISPATCH_FLOATING_TYPES( - input.type(), "replication_pad2d_backward", [&] { + input.scalar_type(), "replication_pad2d_backward_cpu", [&] { replication_pad2d_backward_out_batch( gradInput.data(), gradOutput.data(), @@ -709,7 +709,7 @@ void replication_pad3d_out_cpu_template( if (input.dim() == 4) { output.resize_({nslices, odepth, oheight, owidth}); - AT_DISPATCH_FLOATING_TYPES(input.type(), "replication_pad3d", [&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "replication_pad3d_cpu", [&] { auto 
input_data = input.data(); auto output_data = output.data(); replication_pad3d_out_frame( @@ -722,7 +722,7 @@ void replication_pad3d_out_cpu_template( else { output.resize_({nbatch, nslices, odepth, oheight, owidth}); - AT_DISPATCH_FLOATING_TYPES(input.type(), "replication_pad3d", [&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "replication_pad3d_cpu", [&] { auto input_data = input.data(); auto output_data = output.data(); replication_pad3d_out_batch( @@ -871,7 +871,7 @@ Tensor& replication_pad3d_backward_out_cpu_template( if (input.dim() == 4) { AT_DISPATCH_FLOATING_TYPES( - input.type(), "replication_pad3d_backward", [&] { + input.scalar_type(), "replication_pad3d_backward_cpu", [&] { replication_pad3d_backward_out_frame ( gradInput.data(), gradOutput.data(), @@ -887,7 +887,7 @@ Tensor& replication_pad3d_backward_out_cpu_template( else { AT_DISPATCH_FLOATING_TYPES( - input.type(), "replication_pad3d_backward", [&] { + input.scalar_type(), "replication_pad3d_backward_cpu", [&] { replication_pad3d_backward_out_batch ( gradInput.data(), gradOutput.data(), diff --git a/aten/src/ATen/native/Scalar.cpp b/aten/src/ATen/native/Scalar.cpp index c635c9d..94b8df9 100644 --- a/aten/src/ATen/native/Scalar.cpp +++ b/aten/src/ATen/native/Scalar.cpp @@ -19,7 +19,7 @@ Scalar item(const Tensor& self) { Scalar _local_scalar_dense_cpu(const Tensor& self) { Scalar r; AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND( - at::ScalarType::Half, self.type(), "_local_scalar_dense_cpu", [&] { + at::ScalarType::Half, self.scalar_type(), "_local_scalar_dense_cpu", [&] { scalar_t value = *self.data(); r = Scalar(value); }); diff --git a/aten/src/ATen/native/SoftMax.cpp b/aten/src/ATen/native/SoftMax.cpp index f34c3e1..64d259f 100644 --- a/aten/src/ATen/native/SoftMax.cpp +++ b/aten/src/ATen/native/SoftMax.cpp @@ -132,7 +132,7 @@ Tensor softmax_cpu(const Tensor& input_, const int64_t dim_, const bool half_to_ if (input.ndimension() > 0 && dim == input.ndimension() - 1) { softmax_lastdim_kernel(kCPU, output, input); } else { - AT_DISPATCH_FLOATING_TYPES(input.type(), "softmax", [&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "softmax", [&] { host_softmax(output, input, dim); }); } @@ -152,7 +152,7 @@ Tensor log_softmax_cpu(const Tensor& input_, const int64_t dim_, const bool half if (input.ndimension() > 0 && dim == input.ndimension() - 1) { log_softmax_lastdim_kernel(kCPU, output, input); } else { - AT_DISPATCH_FLOATING_TYPES(input.type(), "log_softmax", [&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "log_softmax", [&] { host_softmax(output, input, dim); }); } @@ -181,7 +181,7 @@ Tensor softmax_backward_cpu( if (grad.ndimension() > 0 && dim == grad.ndimension() - 1) { softmax_backward_lastdim_kernel(kCPU, grad_input, grad, output); } else { - AT_DISPATCH_FLOATING_TYPES(grad.type(), "softmax_backward", [&] { + AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "softmax_backward", [&] { host_softmax_backward(grad_input, grad, output, dim); }); } @@ -210,7 +210,7 @@ Tensor log_softmax_backward_cpu( if (grad.ndimension() > 0 && dim == grad.ndimension() - 1) { log_softmax_backward_lastdim_kernel(kCPU, grad_input, grad, output); } else { - AT_DISPATCH_FLOATING_TYPES(grad.type(), "log_softmax_backward", [&] { + AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "log_softmax_backward", [&] { host_softmax_backward(grad_input, grad, output, dim); }); } diff --git a/aten/src/ATen/native/Sorting.cpp b/aten/src/ATen/native/Sorting.cpp index 4bc3c2c..9f52b2a 100644 --- a/aten/src/ATen/native/Sorting.cpp +++ 
b/aten/src/ATen/native/Sorting.cpp @@ -153,7 +153,7 @@ std::tuple kthvalue_out_cpu( } auto tmp_values = self.clone(); auto tmp_indices = at::empty(self.sizes(), self.options().dtype(kLong)); - AT_DISPATCH_ALL_TYPES(self.type(), "kthvalue", [&] { + AT_DISPATCH_ALL_TYPES(self.scalar_type(), "kthvalue_cpu", [&] { dim_apply( {tmp_values, tmp_indices, values, indices}, dim, diff --git a/aten/src/ATen/native/SummaryOps.cpp b/aten/src/ATen/native/SummaryOps.cpp index 10d7b7b..976dcd7 100644 --- a/aten/src/ATen/native/SummaryOps.cpp +++ b/aten/src/ATen/native/SummaryOps.cpp @@ -55,7 +55,7 @@ Tensor _bincount_cpu_template( Tensor _bincount_cpu(const Tensor& self, const Tensor& weights, int64_t minlength) { - return AT_DISPATCH_INTEGRAL_TYPES(self.type(), "bincount", [&] { + return AT_DISPATCH_INTEGRAL_TYPES(self.scalar_type(), "bincount_cpu", [&] { const auto scalar = weights.scalar_type(); if (scalar == ScalarType::Undefined || scalar == ScalarType::Float) return _bincount_cpu_template(self.contiguous(), weights.contiguous(), minlength); diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index e59988b..c5360b5 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -91,7 +91,7 @@ Tensor where(const Tensor& condition, const Tensor& self, const Tensor& other) { Tensor _s_where_cpu(const Tensor& condition, const Tensor& self, const Tensor& other) { Tensor ret = at::empty(self.sizes(), self.options()); - AT_DISPATCH_ALL_TYPES(ret.type(), "where", [&] { + AT_DISPATCH_ALL_TYPES(ret.scalar_type(), "where_cpu", [&] { where_cpu(ret, condition, self, other); }); return ret; diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 1e712ad..b5d6ec8 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -184,7 +184,7 @@ Tensor& eye_out_cpu(Tensor& result, int64_t n, int64_t m) { result.zero_(); int64_t sz = std::min(n, m); - AT_DISPATCH_ALL_TYPES(result.type(), "eye", [&]() -> void { + AT_DISPATCH_ALL_TYPES(result.scalar_type(), "eye", [&]() -> void { scalar_t* result_data = result.data(); for(int64_t i = 0; i < sz; i++) { result_data[i*(result.strides()[0] + result.strides()[1])] = 1; @@ -453,7 +453,7 @@ Tensor& randperm_out_cpu(Tensor& result, int64_t n, Generator* generator) { AT_CHECK(n >= 0, "n must be non-negative, got", n); result.resize_({n}); auto gen = get_generator(generator); - AT_DISPATCH_ALL_TYPES(result.type(), "randperm", [&]() -> void { + AT_DISPATCH_ALL_TYPES(result.scalar_type(), "randperm", [&]() -> void { randperm_cpu(result, n, gen); }); @@ -501,7 +501,7 @@ Tensor tril_indices_cpu( // // 3. sequential RAM + transpose: create an n X 2 Tensor, fill the Tensor // sequentially, and then transpose it. 
- AT_DISPATCH_ALL_TYPES(result.type(), "tril_indices", [&]() -> void { + AT_DISPATCH_ALL_TYPES(result.scalar_type(), "tril_indices", [&]() -> void { // fill the Tensor with correct values scalar_t* result_data = result.data(); int64_t i = 0; @@ -534,7 +534,7 @@ Tensor triu_indices_cpu( // create an empty Tensor with correct size auto result = at::empty({2, triu_size}, options); - AT_DISPATCH_ALL_TYPES(result.type(), "triu_indices", [&]() -> void { + AT_DISPATCH_ALL_TYPES(result.scalar_type(), "triu_indices", [&]() -> void { // fill the Tensor with correct values scalar_t* result_data = result.data(); int64_t i = 0; @@ -705,7 +705,7 @@ template Tensor tensor_cpu(ArrayRef values, const TensorOptions& options) { auto result = at::empty(values.size(), options); AT_ASSERT(result.is_contiguous()); - AT_DISPATCH_ALL_TYPES(result.type(), "tensor_cpu", [&] { + AT_DISPATCH_ALL_TYPES(result.scalar_type(), "tensor_cpu", [&] { std::copy(values.begin(), values.end(), result.template data()); }); return result; diff --git a/aten/src/ATen/native/TensorIterator.h b/aten/src/ATen/native/TensorIterator.h index 7d79d66..274b5e9 100644 --- a/aten/src/ATen/native/TensorIterator.h +++ b/aten/src/ATen/native/TensorIterator.h @@ -151,7 +151,7 @@ struct CAFFE2_API TensorIterator { AT_ASSERT(operands_[arg].type); return *operands_[arg].type; } - ScalarType dtype(int arg) const { return type(arg).scalarType(); } + ScalarType dtype(int arg=0) const { return type(arg).scalarType(); } DeviceType device_type(int arg=0) const { return type(arg).device_type(); } int64_t element_size(int arg) const { return type(arg).elementSizeInBytes(); } bool is_scalar(int arg) const; diff --git a/aten/src/ATen/native/TensorTransformations.cpp b/aten/src/ATen/native/TensorTransformations.cpp index 9f6f0b4..fcb3d0a 100644 --- a/aten/src/ATen/native/TensorTransformations.cpp +++ b/aten/src/ATen/native/TensorTransformations.cpp @@ -60,7 +60,7 @@ Tensor flip_cpu(const Tensor& self, IntArrayRef dims) { } } - AT_DISPATCH_ALL_TYPES(in_tensor.type(), "flip_cpu", [&] { + AT_DISPATCH_ALL_TYPES(in_tensor.scalar_type(), "flip_cpu", [&] { flip_cpu_kernel( total_dims, stride_contiguous_v, diff --git a/aten/src/ATen/native/TypeProperties.cpp b/aten/src/ATen/native/TypeProperties.cpp index 398c599..9ed7648 100644 --- a/aten/src/ATen/native/TypeProperties.cpp +++ b/aten/src/ATen/native/TypeProperties.cpp @@ -25,7 +25,7 @@ bool is_signed(const Tensor &self) { if (self.scalar_type() == ScalarType::Half) { return true; } - return AT_DISPATCH_ALL_TYPES(self.type(), "is_signed", [&]() -> bool { + return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "is_signed", [&]() -> bool { return std::is_signed(); }); } diff --git a/aten/src/ATen/native/Unique.cpp b/aten/src/ATen/native/Unique.cpp index 8f6cfae..8cc867f 100644 --- a/aten/src/ATen/native/Unique.cpp +++ b/aten/src/ATen/native/Unique.cpp @@ -127,14 +127,14 @@ std::tuple _unique_dim_cpu_template( std::tuple _unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) { - return AT_DISPATCH_ALL_TYPES(self.type(), "unique", [&] { + return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_cpu", [&] { return _unique_cpu_template(self, sorted, return_inverse); }); } std::tuple _unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) { - return AT_DISPATCH_ALL_TYPES(self.type(), "unique_dim", [&] { + return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] { // The current implementation using `dim` always sorts due to unhashable tensors return 
_unique_dim_cpu_template(self, dim, return_inverse); }); diff --git a/aten/src/ATen/native/cpu/Activation.cpp b/aten/src/ATen/native/cpu/Activation.cpp index 1d6e75c..a2e7bd3 100644 --- a/aten/src/ATen/native/cpu/Activation.cpp +++ b/aten/src/ATen/native/cpu/Activation.cpp @@ -9,7 +9,7 @@ namespace at { namespace native { namespace { static void threshold_kernel(TensorIterator& iter, Scalar threshold_scalar, Scalar value_scalar) { - AT_DISPATCH_ALL_TYPES(iter.type(), "threshold", [&] { + AT_DISPATCH_ALL_TYPES(iter.dtype(), "threshold_cpu", [&] { using Vec = Vec256; scalar_t threshold = threshold_scalar.to(); scalar_t value = value_scalar.to(); diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index 2ba95f7..431a6dc 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -14,7 +14,7 @@ namespace { using namespace vec256; void add_kernel(TensorIterator& iter, Scalar alpha_scalar) { - AT_DISPATCH_ALL_TYPES(iter.type(), "add", [&]() { + AT_DISPATCH_ALL_TYPES(iter.dtype(), "add_cpu", [&]() { auto alpha = alpha_scalar.to(); auto alpha_vec = Vec256(alpha); binary_kernel_vec(iter, @@ -30,7 +30,7 @@ void sub_kernel(TensorIterator& iter, Scalar alpha_scalar) { } void mul_kernel(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES(iter.type(), "mul", [&]() { + AT_DISPATCH_ALL_TYPES(iter.dtype(), "mul_cpu", [&]() { binary_kernel_vec(iter, [=](scalar_t a, scalar_t b) -> scalar_t { return a * b; }, [=](Vec256 a, Vec256 b) { @@ -40,16 +40,16 @@ void mul_kernel(TensorIterator& iter) { } void div_kernel(TensorIterator& iter) { - if (isIntegralType(iter.type().scalarType())) { + if (isIntegralType(iter.dtype())) { // There's no SIMD integer division, so don't try to vectorize it. // TODO: if the divisor is a scalar, rewrite as multiplication by a constant. 
- AT_DISPATCH_INTEGRAL_TYPES(iter.type(), "div", [&]() { + AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "div_cpu", [&]() { binary_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t { return a / b; }); }); } else { - AT_DISPATCH_FLOATING_TYPES(iter.type(), "div", [&]() { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "div_cpu", [&]() { binary_kernel_vec(iter, [=](scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ -> scalar_t { return a / b; diff --git a/aten/src/ATen/native/cpu/CopyKernel.cpp b/aten/src/ATen/native/cpu/CopyKernel.cpp index 03ef8df..d8cf81d 100644 --- a/aten/src/ATen/native/cpu/CopyKernel.cpp +++ b/aten/src/ATen/native/cpu/CopyKernel.cpp @@ -15,7 +15,7 @@ constexpr int64_t COPY_GRAIN_SIZE = 20000; static void copy_kernel_impl(Tensor& dst, const Tensor& src) { AT_DISPATCH_ALL_TYPES_AND( - at::ScalarType::Half, dst.type(), "copy_kernel_impl", [&]() { + at::ScalarType::Half, dst.scalar_type(), "copy_kernel_impl", [&]() { scalar_t* self_ptr = dst.data(); scalar_t* src_ptr = src.data(); diff --git a/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp b/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp index f9ff60c..237d907 100644 --- a/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp @@ -270,19 +270,19 @@ struct PDist { }; void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, const double p) { - AT_DISPATCH_FLOATING_TYPES(self.type(), "pdist", [&] { + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist", [&] { PDist::apply(result, self, p); }); } static void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& self, const double p, const Tensor& dist) { - AT_DISPATCH_FLOATING_TYPES(self.type(), "pdist_backward", [&] { + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_backward", [&] { PDist::apply_backward(result, grad, self, p, dist); }); } static void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, const double p) { - AT_DISPATCH_FLOATING_TYPES(result.type(), "cdist", [&] { + AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cdist", [&] { PDist::apply_cdist(result, x1, x2, p); }); } diff --git a/aten/src/ATen/native/cpu/GridSamplerKernel.cpp b/aten/src/ATen/native/cpu/GridSamplerKernel.cpp index 7faff9b..d96589d 100644 --- a/aten/src/ATen/native/cpu/GridSamplerKernel.cpp +++ b/aten/src/ATen/native/cpu/GridSamplerKernel.cpp @@ -308,8 +308,8 @@ static inline void mask_scatter_add(const scalar_t *src, scalar_t* base_addr, const int_same_size_t *offsets, const int_same_size_t *mask, int64_t len) { - #ifndef _MSC_VER - # pragma unroll + #ifndef _MSC_VER + # pragma unroll #endif for (int64_t i = 0; i < len; i++) { if (mask[i] & 0x01) { @@ -431,8 +431,8 @@ struct ApplyGridSample auto i_sw_offset = i_nw_offset + iVec(inp_sH); auto i_se_offset = i_sw_offset + iVec(inp_sW); - #ifndef _MSC_VER - # pragma unroll + #ifndef _MSC_VER + # pragma unroll #endif for (int64_t c = 0; c < C; ++c) { auto inp_slice_C_ptr = inp_slice[c].data(); @@ -505,8 +505,8 @@ struct ApplyGridSample scalar_t gInp_corner_arr[Vec::size()]; auto gx = Vec(0), gy = Vec(0); - #ifndef _MSC_VER - # pragma unroll + #ifndef _MSC_VER + # pragma unroll #endif for (int64_t c = 0; c < C; ++c) { auto inp_slice_C_ptr = inp_slice[c].data(); @@ -598,8 +598,8 @@ struct ApplyGridSample auto out_ptr = out_slice.data() + offset; auto out_sC = out_slice.stride(0); auto inp_slice_ptr = inp_slice.data(); - #ifndef _MSC_VER - # pragma unroll + #ifndef _MSC_VER + # pragma unroll #endif for (int c = 0; c < C; ++c, out_ptr += 
out_sC, inp_slice_ptr += inp_sC) { // mask_gather zeros out the mask, so we need to make a copy @@ -635,8 +635,8 @@ struct ApplyGridSample integer_t gInp_offset_arr[iVec::size()]; i_gInp_offset.store(gInp_offset_arr); - #ifndef _MSC_VER - # pragma unroll + #ifndef _MSC_VER + # pragma unroll #endif for (int64_t c = 0; c < C; ++c) { mask_scatter_add(gOut_slice[c].data() + offset, gInp_slice[c].data(), @@ -743,15 +743,15 @@ static inline void grid_sample_2d_grid_slice_iterator( auto spatial_offset = 0; auto i_offsets_delta = iVec(grid_sW * step); - #ifndef _MSC_VER - # pragma unroll + #ifndef _MSC_VER + # pragma unroll #endif for (int64_t h = 0; h < out_H; h++) { auto grid_ptr_x = grid_ptr + h * grid_sH; auto grid_ptr_y = grid_ptr_x + grid_sCoor; auto i_offsets = iVec::arange(0, grid_sW); - #ifndef _MSC_VER - # pragma unroll + #ifndef _MSC_VER + # pragma unroll #endif for (int64_t w = 0; w < out_W; w += step) { auto len = std::min(step, out_W - w); @@ -815,7 +815,7 @@ Tensor grid_sampler_2d_cpu_kernel_impl(const Tensor& input, const Tensor& grid, return; \ } - AT_DISPATCH_FLOATING_TYPES(input.type(), "grid_sampler_2d_cpu_kernel_impl", [&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_2d_cpu_kernel_impl", [&] { auto out_acc = output.accessor(); auto inp_acc = input.accessor(); auto grid_acc = grid.accessor(); @@ -878,7 +878,7 @@ grid_sampler_2d_backward_cpu_kernel_impl(const Tensor& grad_output_, return; \ } - AT_DISPATCH_FLOATING_TYPES(input.type(), "grid_sampler_2d_backward_cpu_kernel_impl", [&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_2d_backward_cpu_kernel_impl", [&] { auto gInp_acc = grad_input.accessor(); auto gGrid_acc = grad_grid.accessor(); auto inp_acc = input.accessor(); diff --git a/aten/src/ATen/native/cpu/IndexKernel.cpp b/aten/src/ATen/native/cpu/IndexKernel.cpp index 0d93112..698b40b 100644 --- a/aten/src/ATen/native/cpu/IndexKernel.cpp +++ b/aten/src/ATen/native/cpu/IndexKernel.cpp @@ -92,7 +92,7 @@ void cpu_index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef } void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) { - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(0), "index", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_cpu", [&] { cpu_index_kernel(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) { *(scalar_t*)dst = *(scalar_t*)(src + offset); }); @@ -101,7 +101,7 @@ void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef inde void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) { // NOTE: duplicate indices are only supported if accumulate is true. - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(0), "index_put", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_put", [&] { if (accumulate) { // TODO: investigate parallelization of the accumulate kernel. Unlike the non-accumulate case, // this needs to be thread-safe. 
diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index 2f78b59..b895636 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -15,7 +15,7 @@ namespace at { namespace native { namespace { using namespace vec256; static void sum_kernel_impl(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES(iter.type(), "sum", [&] { + AT_DISPATCH_ALL_TYPES(iter.dtype(), "sum_cpu", [&] { binary_kernel_reduce_vec( iter, [=](scalar_t a, scalar_t b) -> scalar_t { return a + b; }, @@ -24,7 +24,7 @@ static void sum_kernel_impl(TensorIterator& iter) { } static void mean_kernel_impl(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES(iter.type(), "mean", [&] { + AT_DISPATCH_ALL_TYPES(iter.dtype(), "mean_cpu", [&] { scalar_t factor = scalar_t(iter.num_output_elements()) / iter.numel(); binary_kernel_reduce( iter, @@ -35,7 +35,7 @@ static void mean_kernel_impl(TensorIterator& iter) { } static void std_var_kernel_impl(TensorIterator &iter, bool unbiased, bool take_sqrt) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.type(), "std", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "std_cpu", [&] { binary_kernel_reduce( iter, WelfordOps { unbiased, take_sqrt }, @@ -45,7 +45,7 @@ static void std_var_kernel_impl(TensorIterator &iter, bool unbiased, bool take_s } static void prod_kernel_impl(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES(iter.type(), "prod", [&] { + AT_DISPATCH_ALL_TYPES(iter.dtype(), "prod_cpu", [&] { binary_kernel_reduce_vec( iter, [=](scalar_t a, scalar_t b) -> scalar_t { return a * b; }, @@ -68,7 +68,7 @@ static void norm_kernel_tensor_iterator_impl( if (val == 0) { - AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&] { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "norm_cpu", [&] { binary_kernel_reduce( iter, NormZeroOps(), @@ -76,7 +76,7 @@ static void norm_kernel_tensor_iterator_impl( ); }); } else if (val == 1) { - AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&] { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "norm_cpu", [&] { binary_kernel_reduce( iter, NormOneOps(), @@ -84,7 +84,7 @@ static void norm_kernel_tensor_iterator_impl( ); }); } else if (val == INFINITY) { - AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&] { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "norm_cpu", [&] { binary_kernel_reduce( iter, AbsMaxOps(), @@ -92,7 +92,7 @@ static void norm_kernel_tensor_iterator_impl( ); }); } else if (val == -INFINITY) { - AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&] { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "norm_cpu", [&] { binary_kernel_reduce( iter, AbsMinOps(), @@ -100,7 +100,7 @@ static void norm_kernel_tensor_iterator_impl( ); }); } else { - AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&] { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "norm_cpu", [&] { binary_kernel_reduce( iter, NormOps { scalar_t(val) }, @@ -149,7 +149,7 @@ static void or_kernel_impl(TensorIterator& iter) { } static void min_values_kernel_impl(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES(iter.type(), "min_values", [&iter] { + AT_DISPATCH_ALL_TYPES(iter.dtype(), "min_values_cpu", [&iter] { binary_kernel_reduce_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { return std::min(a, b); }, @@ -158,7 +158,7 @@ static void min_values_kernel_impl(TensorIterator& iter) { } static void max_values_kernel_impl(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES(iter.type(), "min_values", [&iter] { + AT_DISPATCH_ALL_TYPES(iter.dtype(), "max_values_cpu", [&iter] { binary_kernel_reduce_vec( iter, [](scalar_t a, 
scalar_t b) -> scalar_t { return std::max(a, b); }, diff --git a/aten/src/ATen/native/cpu/SoftMaxKernel.cpp b/aten/src/ATen/native/cpu/SoftMaxKernel.cpp index 83f0232..d838b86 100644 --- a/aten/src/ATen/native/cpu/SoftMaxKernel.cpp +++ b/aten/src/ATen/native/cpu/SoftMaxKernel.cpp @@ -218,7 +218,7 @@ struct vec_host_softmax_backward_lastdim { }; static void softmax_lastdim_kernel_impl(Tensor& result, const Tensor& self) { - AT_DISPATCH_FLOATING_TYPES(self.type(), "softmax_lastdim_kernel_impl", [&] { + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "softmax_lastdim_kernel_impl", [&] { vec_host_softmax_lastdim::apply(result, self); }); } @@ -227,7 +227,7 @@ static void log_softmax_lastdim_kernel_impl( Tensor& result, const Tensor& self) { AT_DISPATCH_FLOATING_TYPES( - self.type(), "log_softmax_lastdim_kernel_impl", [&] { + self.scalar_type(), "log_softmax_lastdim_kernel_impl", [&] { vec_host_softmax_lastdim::apply(result, self); }); } @@ -237,7 +237,7 @@ static void softmax_backward_lastdim_kernel_impl( const Tensor& grad, const Tensor& output) { AT_DISPATCH_FLOATING_TYPES( - grad.type(), "softmax_backward_lastdim_kernel_impl", [&] { + grad.scalar_type(), "softmax_backward_lastdim_kernel_impl", [&] { vec_host_softmax_backward_lastdim::apply( grad_input, grad, output); }); @@ -248,7 +248,7 @@ static void log_softmax_backward_lastdim_kernel_impl( const Tensor& grad, const Tensor& output) { AT_DISPATCH_FLOATING_TYPES( - grad.type(), "log_softmax_backward_lastdim_kernel_impl", [&] { + grad.scalar_type(), "log_softmax_backward_lastdim_kernel_impl", [&] { vec_host_softmax_backward_lastdim::apply( grad_input, grad, output); }); diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp index 5dc8603..d118e59 100644 --- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp +++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp @@ -97,7 +97,7 @@ static void max_kernel_impl( Tensor& max_indices, const Tensor& self, c10::optional dim) { - AT_DISPATCH_ALL_TYPES(self.type(), "max", [&] { + AT_DISPATCH_ALL_TYPES(self.scalar_type(), "max", [&] { Reduction::apply(max, max_indices, self, dim, true); }); } @@ -107,7 +107,7 @@ static void min_kernel_impl( Tensor& min_indices, const Tensor& self, c10::optional dim) { - AT_DISPATCH_ALL_TYPES(self.type(), "min", [&] { + AT_DISPATCH_ALL_TYPES(self.scalar_type(), "min", [&] { Reduction::apply(min, min_indices, self, dim, false); }); } diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp index 31263fd..aaa566c 100644 --- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp @@ -81,7 +81,7 @@ int64_t _sigmoid(double* x, double* y, int64_t size) { } static void sigmoid_kernel(Tensor& result, const Tensor& self) { - AT_DISPATCH_FLOATING_TYPES(self.type(), "sigmoid", [&] { + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "sigmoid", [&] { using Vec = Vec256; CPU_tensor_parallel_kernel_apply2( result, @@ -133,7 +133,7 @@ void bernoulli_mkl_kernel(Tensor &self, const double p, Generator* gen) { int64_t n = self.numel(); bool contig = self.is_contiguous(); - AT_DISPATCH_ALL_TYPES(self.type(), "bernoulli_scalar_cpu_", [&] { + AT_DISPATCH_ALL_TYPES(self.scalar_type(), "bernoulli_scalar_cpu_", [&] { at::Tensor tmp_int_tensor; if (std::is_same::value && contig) { tmp_int_tensor = self; @@ -177,7 +177,7 @@ void bernoulli_mkl_kernel(Tensor &self, const double p, Generator* gen) { #define IMPLEMENT_FLOAT_KERNEL(dispatchtypes, op) \ static 
void op##_kernel(Tensor& result, const Tensor& self) { \ checkBackend(#op, {result}, Backend::CPU); \ - AT_DISPATCH_##dispatchtypes##_TYPES(self.type(), #op, [&] { \ + AT_DISPATCH_##dispatchtypes##_TYPES(self.scalar_type(), #op, [&] { \ if (self.is_contiguous() && result.is_contiguous()) { \ vml::v##op( \ result.data(), self.data(), self.numel()); \ diff --git a/aten/src/ATen/native/cuda/Activation.cu b/aten/src/ATen/native/cuda/Activation.cu index 83e879a..e0c2148 100644 --- a/aten/src/ATen/native/cuda/Activation.cu +++ b/aten/src/ATen/native/cuda/Activation.cu @@ -62,7 +62,7 @@ Tensor prelu_cuda(const Tensor& self, const Tensor& weight_) { // case1: shared weight for all channels if (weight_num == 1) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "prelu_cuda", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "prelu_cuda", [&] { prelu_cuda_kernel_share_weights( input, result, @@ -94,7 +94,7 @@ Tensor prelu_cuda(const Tensor& self, const Tensor& weight_) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); AT_CHECK(cuda::getApplyGrid(input_numel, grid, curDevice), "prelu: input too large or too many dimensions"); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "prelu_cuda", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "prelu_cuda", [&] { prelu_cuda_kernel_multi_weights <<>>( result.data(), @@ -175,7 +175,7 @@ std::tuple prelu_backward_cuda(const Tensor& grad_out_, const Te Tensor weight_grad_collector = at::empty_like(input); // case1: shared parameter for all channels if (weight_num == 1) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "prelu_backward_cuda", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "prelu_backward_cuda", [&] { prelu_cuda_backward_kernel_share_weights( input, grad_out, @@ -210,7 +210,7 @@ std::tuple prelu_backward_cuda(const Tensor& grad_out_, const Te cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); AT_CHECK(cuda::getApplyGrid(input_numel, grid, curDevice), "prelu_backward_cuda: input too large or too many dimensions"); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "prelu_backward_cuda", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "prelu_backward_cuda", [&] { prelu_cuda_backward_kernel_multi_weights <<>>( input.data(), @@ -264,7 +264,7 @@ void hardshrink_backward_cuda_kernel(const Tensor& self, Tensor& out_tensor, sca Tensor hardshrink_cuda(const Tensor & self, Scalar lambd) { auto out_tensor = at::empty_like(self); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "hardshrink_cuda", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "hardshrink_cuda", [&] { hardshrink_cuda_kernel(self, out_tensor, lambd.to()); }); return out_tensor; @@ -272,7 +272,7 @@ Tensor hardshrink_cuda(const Tensor & self, Scalar lambd) { Tensor hardshrink_backward_cuda(const Tensor & grad, const Tensor & self, Scalar lambd) { auto out_tensor = at::empty_like(grad); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "hardshrink_backward_cuda", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "hardshrink_backward_cuda", [&] { hardshrink_backward_cuda_kernel(self, out_tensor, lambd.to(), grad); }); return out_tensor; @@ -286,7 +286,7 @@ void threshold_kernel_impl(TensorIterator& iter, scalar_t threshold, scalar_t va } static void threshold_kernel(TensorIterator& iter, Scalar threshold, Scalar value) { - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "threshold", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, 
iter.dtype(), "threshold_cuda", [&] { threshold_kernel_impl(iter, threshold.to(), value.to()); }); } diff --git a/aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu b/aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu index 5828248..7211aa3 100644 --- a/aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu +++ b/aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu @@ -244,7 +244,7 @@ namespace { output.resize_({sizeD, osizeH, osizeW}); } AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input_.type(), "adaptive_avg_pool2d", [&] { + input_.scalar_type(), "adaptive_avg_pool2d_cuda", [&] { scalar_t *input_data = input_.data(); scalar_t *output_data = output.data(); @@ -284,13 +284,13 @@ namespace { int64_t osizeH = gradOutput.size(-2); int64_t osizeW = gradOutput.size(-1); - + int64_t grid_x = sizeD; if (input.ndimension() == 4) grid_x *= input.size(-4); //bool atomic = (isizeW%osizeW != 0) || (isizeH%osizeH != 0); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.type(), "adaptive_avg_pool2d_backward", [&] { + input.scalar_type(), "adaptive_avg_pool2d_backward_cuda", [&] { scalar_t *gradOutput_data = gradOutput.data(); scalar_t *gradInput_data = gradInput.data(); diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index 1f0184b..291e0b5 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -260,13 +260,13 @@ std::tuple _gesv_helper_cuda(const Tensor& self, const Tensor& A auto self_working_copy = cloneBatchedColumnMajor(self); auto A_working_copy = cloneBatchedColumnMajor(A); std::vector infos(batchCount(self), 0); - AT_DISPATCH_FLOATING_TYPES(self.type(), "gesv", [&]{ + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "gesv_cuda", [&]{ apply_gesv(self_working_copy, A_working_copy, infos); }); if (self.dim() > 2) { - batchCheckErrors(infos, "gesv"); + batchCheckErrors(infos, "gesv_cuda"); } else { - singleCheckErrors(infos[0], "gesv"); + singleCheckErrors(infos[0], "gesv_cuda"); } return std::tuple(self_working_copy, A_working_copy); } @@ -327,11 +327,11 @@ Tensor _inverse_helper_cuda(const Tensor& self) { std::vector infos(batchCount(self), 0); auto self_working_copy = cloneBatchedColumnMajor(self); auto self_inv_working_copy = cloneBatchedColumnMajor(self); - AT_DISPATCH_FLOATING_TYPES(self.type(), "inverse", [&]{ + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "inverse_cuda", [&]{ apply_inverse( self_working_copy, self_inv_working_copy, infos); }); - batchCheckErrors(infos, "inverse"); + batchCheckErrors(infos, "inverse_cuda"); return self_inv_working_copy; } @@ -386,7 +386,7 @@ Tensor _cholesky_solve_helper_cuda(const Tensor& self, const Tensor& A, bool upp int64_t info = 0; auto self_working_copy = cloneBatchedColumnMajor(self); auto A_working_copy = cloneBatchedColumnMajor(A); - AT_DISPATCH_FLOATING_TYPES(self.type(), "cholesky_solve", [&]{ + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "cholesky_solve_cuda", [&]{ apply_cholesky_solve(self_working_copy, A_working_copy, upper, info); }); AT_CHECK(info == 0, "MAGMA cholesky_solve : invalid argument: ", -info); @@ -446,13 +446,13 @@ Tensor _cholesky_helper_cuda(const Tensor& self, bool upper) { self_working_copy = cloneBatchedColumnMajor(self); } - AT_DISPATCH_FLOATING_TYPES(self.type(), "cholesky", [&]{ + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "cholesky_cuda", [&]{ apply_cholesky(self_working_copy, false, infos); }); if (self.dim() > 2) { - batchCheckErrors(infos, "cholesky"); + batchCheckErrors(infos, "cholesky_cuda"); } 
else { - singleCheckErrors(infos[0], "cholesky"); + singleCheckErrors(infos[0], "cholesky_cuda"); } if (upper) { return self_working_copy.transpose(-1, -2); @@ -493,7 +493,7 @@ Tensor& triu_tril_cuda_template(Tensor& result, const Tensor& self, int64_t k, c self_row_stride = self.stride(-2), self_col_stride = self.stride(-1); dim3 dim_block = cuda::getApplyBlock(); dim3 dim_grid((mat_size + dim_block.x - 1) / dim_block.x, n_batches); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), name, [&]{ + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.scalar_type(), name, [&]{ triu_tril_kernel <<>>( result.data(), self.data(), k, mat_size, diff --git a/aten/src/ATen/native/cuda/BinaryOpsKernel.cu b/aten/src/ATen/native/cuda/BinaryOpsKernel.cu index 5cb1212..2b8e338 100644 --- a/aten/src/ATen/native/cuda/BinaryOpsKernel.cu +++ b/aten/src/ATen/native/cuda/BinaryOpsKernel.cu @@ -22,7 +22,7 @@ void add_kernel_impl(TensorIterator& iter, Scalar alpha_scalar) { } static void add_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) { - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "add", [&]() { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "add_cuda", [&]() { add_kernel_impl(iter, alpha_scalar); }); } @@ -46,21 +46,21 @@ void div_constant_impl(TensorIterator& iter, scalar_t inv_b) { } static void div_kernel_cuda(TensorIterator& iter) { - if (isIntegralType(iter.type().scalarType())) { - AT_DISPATCH_INTEGRAL_TYPES(iter.type(), "div", [&]() { + if (isIntegralType(iter.dtype())) { + AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "div_cuda", [&]() { div_kernel_impl(iter); }); } else if (iter.is_cpu_scalar(2)) { // optimization for floating-point types: if the second operand is a CPU // scalar, compute a * reciprocal(b). Note that this may lose one bit of // precision compared to computing the division. 
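The comment above motivates the reciprocal path; its precision remark can be checked in isolation. A standalone sketch, not part of the patch, comparing a true division with a multiply by a precomputed reciprocal:

#include <cstdio>

// Illustrative only: 1.0f / b is rounded to float before the multiply, so the
// product can land on a neighbouring float compared to the correctly rounded a / b.
int main() {
  float a = 10.0f;
  float b = 3.0f;
  float by_division   = a / b;
  float by_reciprocal = a * (1.0f / b);
  std::printf("a / b     = %.9g\n", by_division);
  std::printf("a * (1/b) = %.9g\n", by_reciprocal);
  std::printf("bitwise equal: %d\n", by_division == by_reciprocal);
  return 0;
}

With IEEE-754 single precision and round-to-nearest, the two results for this input differ by one unit in the last place, which is the one bit of precision the comment accepts losing when the second operand is a CPU scalar.
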
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.type(), "div", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "div_cuda", [&]() { auto inv_b = scalar_t(1.0 / iter.scalar_value(2)); iter.remove_operand(2); div_constant_impl(iter, inv_b); }); } else { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.type(), "div", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "div_cuda", [&]() { div_kernel_impl(iter); }); } @@ -74,7 +74,7 @@ void mul_kernel_impl(TensorIterator& iter) { } static void mul_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "mul", [&]() { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "mul_cuda", [&]() { mul_kernel_impl(iter); }); } diff --git a/aten/src/ATen/native/cuda/CUDAScalar.cu b/aten/src/ATen/native/cuda/CUDAScalar.cu index 70b9236..68079ad 100644 --- a/aten/src/ATen/native/cuda/CUDAScalar.cu +++ b/aten/src/ATen/native/cuda/CUDAScalar.cu @@ -10,7 +10,7 @@ namespace native { Scalar _local_scalar_dense_cuda(const Tensor& self) { Scalar r; AT_DISPATCH_ALL_TYPES_AND( - at::ScalarType::Half, self.type(), "_local_scalar_dense_cuda", [&] { + at::ScalarType::Half, self.scalar_type(), "_local_scalar_dense_cuda", [&] { scalar_t value; cudaStream_t stream = at::cuda::getCurrentCUDAStream(); AT_CUDA_CHECK(cudaMemcpyAsync(&value, self.data(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream)); diff --git a/aten/src/ATen/native/cuda/Copy.cu b/aten/src/ATen/native/cuda/Copy.cu index ab85b48..01c0782 100644 --- a/aten/src/ATen/native/cuda/Copy.cu +++ b/aten/src/ATen/native/cuda/Copy.cu @@ -169,7 +169,7 @@ void copy_from_cpu(Tensor& dst, const Tensor& src) { cudaMemcpyHostToDevice, stream)); AT_CUDA_CHECK(cudaStreamSynchronize(stream)); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "copy_from_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "copy_from_cpu", [&]() { copy_device_to_device(dst, dst_contig); }); } @@ -202,7 +202,7 @@ void copy_from_cpu_async_(Tensor& dst, const Tensor& src) { CUDAGuard device_guard(dst.device()); CUDAStream stream = getCurrentCUDAStream(); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "copy_from_cpu_async", [&]() { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "copy_from_cpu_async", [&]() { AT_CUDA_CHECK(cudaMemcpyAsync( dst.data(), src.data(), @@ -225,7 +225,7 @@ void copy_to_cpu_async_(Tensor& dst, const Tensor& src) { CUDAGuard device_guard(src.device()); CUDAStream stream = getCurrentCUDAStream(); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "copy_to_cpu_async", [&]() { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "copy_to_cpu_async", [&]() { AT_CUDA_CHECK(cudaMemcpyAsync( dst.data(), src.data(), @@ -240,7 +240,7 @@ void copy_to_cpu_async_(Tensor& dst, const Tensor& src) { template void _copy__cuda(Tensor& dst, const Tensor& src, bool non_blocking) { AT_CHECK(dst.numel() == src.numel(), "sizes do not match"); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "_copy__cuda", [&]() { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_copy__cuda", [&]() { if (dst.is_cuda() && src.is_cuda()) { copy_device_to_device(dst, src); } else if (dst.is_cuda()) { @@ -279,7 +279,7 @@ namespace at { namespace native { Tensor& _s_copy__cuda(Tensor& self, const Tensor& src, bool non_blocking) { - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), "_copy__cuda", [&]() { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, 
self.scalar_type(), "_copy__cuda", [&]() { ::_copy__cuda(self, src, non_blocking); }); return self; diff --git a/aten/src/ATen/native/cuda/DistanceKernel.cu b/aten/src/ATen/native/cuda/DistanceKernel.cu index 37d5c68..8316bde 100644 --- a/aten/src/ATen/native/cuda/DistanceKernel.cu +++ b/aten/src/ATen/native/cuda/DistanceKernel.cu @@ -188,7 +188,7 @@ void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, doubl const dim3 grid(r1*r2); const dim3 block(forward_threads); - AT_DISPATCH_FLOATING_TYPES(x1.type(), "cdist_cuda", [&] { + AT_DISPATCH_FLOATING_TYPES(x1.scalar_type(), "cdist_cuda", [&] { if (p == 0.0) { cdist_kernel_cuda_impl::zero><<>>(result.data(), x1.data(), x2.data(), p, r1, r2, m); } else if (p == 1.0) { @@ -213,7 +213,7 @@ void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, double p) { const double n2 = n - .5; const double n2_squared_minus_1 = n2 * n2 - 1; - AT_DISPATCH_FLOATING_TYPES(self.type(), "pdist_cuda", [&] { + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_cuda", [&] { if (p == 0.0) { pdist_kernel_cuda_impl::zero><<>>(result.data(), self.data(), n, m, p, n2, n2_squared_minus_1); } else if (p == 1.0) { @@ -252,7 +252,7 @@ void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor const double n2_squared_minus_1 = n2 * n2 - 1; Tensor buffer = at::empty({n - 1, result.size(0), result.size(1)}, result.options()); - AT_DISPATCH_FLOATING_TYPES(self.type(), "pdist_cuda_backward", [&] { + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_cuda_backward", [&] { if (p == 1.0) { pdist_backward_kernel_cuda_impl::one><<>>(buffer.data(), grad.data(), self.data(), dist.data(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); } else if (p < 2.0) { diff --git a/aten/src/ATen/native/cuda/Distributions.cu b/aten/src/ATen/native/cuda/Distributions.cu index c262384..2176547 100644 --- a/aten/src/ATen/native/cuda/Distributions.cu +++ b/aten/src/ATen/native/cuda/Distributions.cu @@ -186,7 +186,7 @@ void bernoulli_scalar_cuda_kernel( namespace at { namespace native { Tensor _s_poisson_cuda(const Tensor& lambda, Generator* gen) { Tensor ret = at::empty(lambda.sizes(), lambda.options()); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.type(), "poisson", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.scalar_type(), "poisson_cuda", [&] { poisson_cuda_kernel(ret, lambda, next_philox_seed(gen, 20)); }); return ret; @@ -194,7 +194,7 @@ Tensor _s_poisson_cuda(const Tensor& lambda, Generator* gen) { Tensor _s_gamma_cuda(const Tensor& alpha, Generator* gen) { Tensor ret = at::empty(alpha.sizes(), alpha.options()); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.type(), "gamma", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.scalar_type(), "gamma_cuda", [&] { gamma_cuda_kernel(ret, alpha, next_philox_seed(gen, 10)); }); return ret; @@ -202,7 +202,7 @@ Tensor _s_gamma_cuda(const Tensor& alpha, Generator* gen) { Tensor _standard_gamma_grad_cuda(const Tensor& self, const Tensor& output) { Tensor ret = at::empty(self.sizes(), self.options()); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "_standard_gamma_grad", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "_standard_gamma_grad_cuda", [&] { gamma_grad_cuda_kernel(ret, self, output); }); return ret; @@ -211,11 +211,11 @@ Tensor _standard_gamma_grad_cuda(const Tensor& self, const Tensor& output) { Tensor& bernoulli_tensor_cuda_(Tensor &self, const Tensor& p_, Generator* gen) { auto p = std::get<0>(expand_inplace(self, p_.to(kCUDA))); AT_DISPATCH_ALL_TYPES_AND( - 
at::ScalarType::Half, self.type(), "bernoulli_tensor_cuda_self_", [&] { + at::ScalarType::Half, self.scalar_type(), "bernoulli_tensor_cuda_self_", [&] { const at::Type& p_type = p.type(); using self_t = scalar_t; auto seeds = next_philox_seed(gen, 10); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(p.type(), "bernoulli_tensor_cuda_p_", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, p.scalar_type(), "bernoulli_tensor_cuda_p_", [&] { using p_t = scalar_t; return bernoulli_tensor_cuda_kernel(self, p, seeds); }); @@ -225,7 +225,7 @@ Tensor& bernoulli_tensor_cuda_(Tensor &self, const Tensor& p_, Generator* gen) { Tensor& bernoulli_scalar_cuda_(Tensor &self, double p, Generator* gen) { AT_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), "bernoulli_scalar_cuda_", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.scalar_type(), "bernoulli_scalar_cuda_", [&] { auto seeds = next_philox_seed(gen, 10); bernoulli_scalar_cuda_kernel(self, p, seeds); }); diff --git a/aten/src/ATen/native/cuda/Dropout.cu b/aten/src/ATen/native/cuda/Dropout.cu index 609a691..16d0b71 100644 --- a/aten/src/ATen/native/cuda/Dropout.cu +++ b/aten/src/ATen/native/cuda/Dropout.cu @@ -108,7 +108,7 @@ fused_dropout_cuda(const Tensor& self, double p, Generator * gen){ //number of times random will be generated per thread, to offset philox counter in thc random state int64_t counter_offset = ((nelem - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL; if (cuda::detail::canUse32BitIndexMath(self)){ - AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "fused_dropout", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "fused_dropout", [&] { using accscalar_t = acc_type; accscalar_t pa = (accscalar_t)(p); auto self_info = cuda::detail::getTensorInfo(self); @@ -126,7 +126,7 @@ fused_dropout_cuda(const Tensor& self, double p, Generator * gen){ } }); } else { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "fused_dropout", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "fused_dropout", [&] { using accscalar_t = acc_type; accscalar_t pa = (accscalar_t)(p); auto self_info = cuda::detail::getTensorInfo(self); @@ -151,7 +151,7 @@ fused_dropout_cuda(const Tensor& self, double p, Generator * gen){ Tensor masked_scale_cuda(const Tensor& self, const Tensor& mask, double scale){ Tensor ret = at::empty_like(self); AT_CHECK(mask.scalar_type() == at::ScalarType::Byte, "mask should be torch.uint8 dtype"); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.type(), "masked_scale", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.scalar_type(), "masked_scale", [&] { using accscalar_t = acc_type; accscalar_t pa = (accscalar_t)(scale); masked_scale_kernel(ret, self, mask, pa); diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu index 44fe18b..8a24923 100644 --- a/aten/src/ATen/native/cuda/Embedding.cu +++ b/aten/src/ATen/native/cuda/Embedding.cu @@ -243,7 +243,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice dim3 block(WARP_SIZE, BLOCKDIMY); AT_DISPATCH_FLOATING_TYPES_AND_HALF - (grad.type(), + (grad.scalar_type(), "embedding_backward", [&] { @@ -326,7 +326,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice dim3 grid(THCCeilDiv(num_indices, (int64_t) 4), THCCeilDiv(stride, (int64_t) 128)); dim3 block(32, 4); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "embedding_backward", [&] { + 
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.scalar_type(), "embedding_backward", [&] { embedding_backward_kernel<<>>( sorted_indices.data(), orig_indices.data(), @@ -371,7 +371,7 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices, dim3 block(128); int dim = self.stride(0); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "embedding_backward", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "embedding_backward", [&] { using accscalar_t = acc_type; renorm_kernel<<>>( self.data(), diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu index 4721611..dea987e 100644 --- a/aten/src/ATen/native/cuda/EmbeddingBag.cu +++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu @@ -248,7 +248,7 @@ Tensor embedding_bag_backward_cuda_sum_avg( dim3 grid(THCCeilDiv(numel, (ptrdiff_t)4), THCCeilDiv(stride, (int64_t)128)); dim3 block(32, 4); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.type(), "embedding_bag_backward_cuda_sum_avg_kernel", [&] { + grad.scalar_type(), "embedding_bag_backward_cuda_sum_avg_kernel", [&] { EmbeddingBag_accGradParametersKernel_sum_avg< scalar_t><<>>( sorted_indices.data(), orig_indices.data(), @@ -304,7 +304,7 @@ Tensor embedding_bag_backward_cuda_max(const Tensor &grad, int grid = 1024; AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.type(), "embedding_bag_backward_cuda_max", [&] { + grad.scalar_type(), "embedding_bag_backward_cuda_max", [&] { EmbeddingBag_accGradParametersKernel_max< scalar_t><<>>( max_indices.data(), grad.data(), @@ -353,7 +353,7 @@ _embedding_bag_cuda(const Tensor &weight, const Tensor &indices, dim3 block = dim3(32, 8); int grid = 1024; - AT_DISPATCH_FLOATING_TYPES_AND_HALF(weight.type(), "embedding_bag_cuda", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(weight.scalar_type(), "embedding_bag_cuda", [&] { EmbeddingBag_updateOutputKernel<<>>( indices.data(), offsets.data(), weight.data(), output.data(), diff --git a/aten/src/ATen/native/cuda/FractionalMaxPool2d.cu b/aten/src/ATen/native/cuda/FractionalMaxPool2d.cu index e8f6939..7a09514 100644 --- a/aten/src/ATen/native/cuda/FractionalMaxPool2d.cu +++ b/aten/src/ATen/native/cuda/FractionalMaxPool2d.cu @@ -195,7 +195,7 @@ void fractional_max_pool2d_out_cuda_template( input_.size(0)); dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "fractional_max_pool2d_out_cuda_frame", [&] { auto devInput = input_.packed_accessor(); @@ -267,7 +267,7 @@ void fractional_max_pool2d_backward_out_cuda_template( dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize); auto devIndices = indices.packed_accessor(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(gradOutput.type(), + AT_DISPATCH_FLOATING_TYPES_AND_HALF(gradOutput.scalar_type(), "fractional_max_pool2d_backward_out_cuda_frame", [&] { auto devGradInput = gradInput_.packed_accessor(); diff --git a/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu b/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu index 79015dd..95f9b1a 100644 --- a/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu +++ b/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu @@ -231,7 +231,7 @@ void fractional_max_pool3d_out_cuda_template( dim3 block(outputPlaneSize > 128 ? 
128 : outputPlaneSize); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.type(), + input.scalar_type(), "fractional_max_pool3d_out_frame", [&]{ fractional_max_pool3d_out_frame @@ -321,7 +321,7 @@ void fractional_max_pool3d_backward_out_cuda_template( dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - gradOutput.type(), + gradOutput.scalar_type(), "fractional_max_pool3d_backward_out_frame", [&] { fractional_max_pool3d_backward_out_frame diff --git a/aten/src/ATen/native/cuda/GridSampler.cu b/aten/src/ATen/native/cuda/GridSampler.cu index 706cd44..54cf637 100644 --- a/aten/src/ATen/native/cuda/GridSampler.cu +++ b/aten/src/ATen/native/cuda/GridSampler.cu @@ -791,7 +791,7 @@ Tensor grid_sampler_2d_cuda(const Tensor& input, const Tensor& grid, auto output = at::empty({N, input.size(1), H, W}, input.options()); int count = static_cast(N * H * W); if (count > 0) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "grid_sampler_2d_cuda", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_cuda", [&] { grid_sampler_2d_kernel <<>>( count, @@ -815,7 +815,7 @@ Tensor grid_sampler_3d_cuda(const Tensor& input, const Tensor& grid, auto output = at::empty({N, input.size(1), D, H, W}, input.options()); int count = static_cast(N * D * H * W); if (count > 0) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "grid_sampler_2d_cuda", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_cuda", [&] { grid_sampler_3d_kernel <<>>( count, @@ -840,7 +840,7 @@ grid_sampler_2d_backward_cuda(const Tensor& grad_output, const Tensor& input, co auto grad_grid = at::empty_like(grid); int count = static_cast(N * H * W); if (count > 0) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "grid_sampler_2d_backward_cuda", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_backward_cuda", [&] { grid_sampler_2d_backward_kernel <<>>( count, @@ -868,7 +868,7 @@ grid_sampler_3d_backward_cuda(const Tensor& grad_output, const Tensor& input, co auto grad_grid = at::empty_like(grid); int count = static_cast(N * D * H * W); if (count > 0) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "grid_sampler_3d_backward_cuda", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_backward_cuda", [&] { grid_sampler_3d_backward_kernel <<>>( count, diff --git a/aten/src/ATen/native/cuda/IndexKernel.cu b/aten/src/ATen/native/cuda/IndexKernel.cu index be69cde..d498179 100644 --- a/aten/src/ATen/native/cuda/IndexKernel.cu +++ b/aten/src/ATen/native/cuda/IndexKernel.cu @@ -81,7 +81,7 @@ void index_put_kernel_impl(TensorIterator& iter, IntArrayRef index_size, IntArra } static void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) { - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "index", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_cuda", [&] { using dtype = OpaqueType; index_kernel_impl(iter, index_size, index_stride); }); @@ -90,7 +90,7 @@ static void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayR static void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) { AT_ASSERTM(!accumulate, "index_put does not support accumulate=true"); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "index_put", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_put", [&] { using dtype = OpaqueType; 
index_put_kernel_impl(iter, index_size, index_stride); }); diff --git a/aten/src/ATen/native/cuda/Lerp.cu b/aten/src/ATen/native/cuda/Lerp.cu index 35a46f1..1946427 100644 --- a/aten/src/ATen/native/cuda/Lerp.cu +++ b/aten/src/ATen/native/cuda/Lerp.cu @@ -38,9 +38,9 @@ Tensor& lerp_cuda_tensor_out(Tensor& result, const Tensor& self, Tensor b_self, b_end, b_weight; AT_CHECK(weight.dim() <= std::max(self.dim(), end.dim()), "weight should be of dimension max(self.dim(), end.dim()) or lesser"); - std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp_out"); + std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp_out_cuda"); result.resize_as_(b_self); - AT_DISPATCH_FLOATING_TYPES(self.type(), "lerp_out", [&]{ + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp_out_cuda", [&]{ lerp_cuda(result, b_self, b_end, b_weight); }); return result; @@ -49,9 +49,9 @@ Tensor& lerp_cuda_tensor_out(Tensor& result, const Tensor& self, Tensor& lerp_cuda_scalar_out(Tensor& result, const Tensor& self, const Tensor& end, Scalar weight) { Tensor b_self, b_end; - std::tie(b_self, b_end) = expand_outplace(self, end, "lerp_out"); + std::tie(b_self, b_end) = expand_outplace(self, end, "lerp_out_cuda"); result.resize_as_(b_self); - AT_DISPATCH_FLOATING_TYPES(self.type(), "lerp_out", [&]{ + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp_out_cuda", [&]{ lerp_cuda(result, b_self, b_end, weight.to()); }); return result; @@ -59,13 +59,13 @@ Tensor& lerp_cuda_scalar_out(Tensor& result, const Tensor& self, Tensor& lerp_cuda_tensor_(Tensor& self, const Tensor& end, const Tensor& weight) { Tensor b_self, b_end, b_weight; - std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp_"); + std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp__cuda"); AT_CHECK(b_self.sizes() == self.sizes(), "output with shape ", self.sizes(), " doesn't match the broadcast shape ", b_self.sizes()); AT_CHECK(weight.dim() <= std::max(self.dim(), end.dim()), "weight should be of dimension max(self.dim(), end.dim()) or lesser"); - AT_DISPATCH_FLOATING_TYPES(self.type(), "lerp_", [&]{ + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp__cuda", [&]{ lerp_cuda(self, b_self, b_end, b_weight); }); return self; @@ -73,11 +73,11 @@ Tensor& lerp_cuda_tensor_(Tensor& self, const Tensor& end, const Tensor& weight) Tensor& lerp_cuda_scalar_(Tensor& self, const Tensor& end, Scalar weight) { Tensor b_self, b_end; - std::tie(b_self, b_end) = expand_outplace(self, end, "lerp_"); + std::tie(b_self, b_end) = expand_outplace(self, end, "lerp__cuda"); AT_CHECK(b_self.sizes() == self.sizes(), "output with shape ", self.sizes(), " doesn't match the broadcast shape ", b_self.sizes()); - AT_DISPATCH_FLOATING_TYPES(self.type(), "lerp_", [&]{ + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp__cuda", [&]{ lerp_cuda(self, b_self, b_end, weight.to()); }); return self; @@ -87,9 +87,9 @@ Tensor lerp_cuda_tensor(const Tensor& self, const Tensor& end, const Tensor& wei Tensor b_self, b_end, b_weight; AT_CHECK(weight.dim() <= std::max(self.dim(), end.dim()), "weight should be of dimension max(self.dim(), end.dim()) or lesser"); - std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp"); + std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp_cuda"); Tensor result = at::empty_like(b_self); - AT_DISPATCH_FLOATING_TYPES(result.type(), "lerp", [&]{ + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp_cuda", [&]{ lerp_cuda(result, b_self, 
b_end, b_weight); }); return result; @@ -97,9 +97,9 @@ Tensor lerp_cuda_tensor(const Tensor& self, const Tensor& end, const Tensor& wei Tensor lerp_cuda_scalar(const Tensor& self, const Tensor& end, Scalar weight) { Tensor b_self, b_end; - std::tie(b_self, b_end) = expand_outplace(self, end, "lerp"); + std::tie(b_self, b_end) = expand_outplace(self, end, "lerp_cuda"); Tensor result = at::empty_like(b_self); - AT_DISPATCH_FLOATING_TYPES(result.type(), "lerp", [&]{ + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp_cuda", [&]{ lerp_cuda(result, b_self, b_end, weight.to()); }); return result; diff --git a/aten/src/ATen/native/cuda/Loss.cu b/aten/src/ATen/native/cuda/Loss.cu index 50678c1..b60de0f 100644 --- a/aten/src/ATen/native/cuda/Loss.cu +++ b/aten/src/ATen/native/cuda/Loss.cu @@ -28,7 +28,7 @@ namespace at { namespace native { Tensor kl_div_backward_cuda(const Tensor& grad, const Tensor& input, const Tensor& target, int64_t reduction) { auto grad_input = at::zeros_like(input); Tensor grad_expand = grad.expand_as(input); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "kl_div_backward", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "kl_div_backward_cuda", [&]() { kl_div_backward_kernel(grad_input, target, grad_expand); }); if (reduction == Reduction::Mean) { diff --git a/aten/src/ATen/native/cuda/LossCTC.cu b/aten/src/ATen/native/cuda/LossCTC.cu index 646a541..547dd6c 100644 --- a/aten/src/ATen/native/cuda/LossCTC.cu +++ b/aten/src/ATen/native/cuda/LossCTC.cu @@ -102,8 +102,7 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data, bool have_three; // flag which of the two cases in eq (6) we have if (s < 2*target_length+1) { current_char = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK); - have_three = ((s > 1) && (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s-2, BLANK) != - current_char)); + have_three = ((s > 1) && (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s-2, BLANK) != current_char)); } else { current_char = BLANK; have_three = false; @@ -631,7 +630,7 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_ std::tuple ctc_loss_gpu(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK, bool zero_infinity) { (void)zero_infinity; // only used for backward - return AT_DISPATCH_FLOATING_TYPES(log_probs.type(), "ctc_loss", [&] { + return AT_DISPATCH_FLOATING_TYPES(log_probs.scalar_type(), "ctc_loss_cuda", [&] { if (targets.scalar_type() == kLong) { return ctc_loss_gpu_template(log_probs, targets, input_lengths, target_lengths, BLANK); } else { @@ -642,7 +641,7 @@ std::tuple ctc_loss_gpu(const Tensor& log_probs, const Tensor& t Tensor ctc_loss_backward_gpu(const Tensor& grad, const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, const Tensor& neg_log_likelihood, const Tensor& log_alpha, int64_t BLANK, bool zero_infinity) { - return AT_DISPATCH_FLOATING_TYPES(log_probs.type(), "ctc_loss_backward", [&] { + return AT_DISPATCH_FLOATING_TYPES(log_probs.scalar_type(), "ctc_loss_backward_cuda", [&] { if (targets.scalar_type() == kLong) { return ctc_loss_backward_gpu_template(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, BLANK, zero_infinity); } else { diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu index d77236e..8f92acd 100644 --- 
a/aten/src/ATen/native/cuda/Normalization.cu +++ b/aten/src/ATen/native/cuda/Normalization.cu @@ -4,7 +4,7 @@ namespace at { namespace native { std::tuple batch_norm_cuda(const Tensor& self, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool train, double momentum, double epsilon) { - return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm", [&] { + return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_cuda", [&] { if (cuda::detail::canUse32BitIndexMath(self)) { return batch_norm_cuda_template(self, weight, bias, running_mean, running_var, train, momentum, epsilon); } else { @@ -15,7 +15,7 @@ std::tuple batch_norm_cuda(const Tensor& self, const Ten std::tuple batch_norm_backward_cuda(const Tensor& grad_out, const Tensor& self, const Tensor& weight, const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd, bool train, double epsilon, std::array grad_input_mask) { - return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_backward", [&] { + return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_backward_cuda", [&] { if (cuda::detail::canUse32BitIndexMath(self)) { return batch_norm_backward_cuda_template(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); } else { @@ -25,7 +25,7 @@ std::tuple batch_norm_backward_cuda(const Tensor& grad_o } std::tuple batch_norm_stats_cuda(const Tensor& self, double epsilon) { - return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_stats", [&] { + return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_stats_cuda", [&] { if (cuda::detail::canUse32BitIndexMath(self)) { return batch_norm_stats_cuda_template(self, epsilon); } else { @@ -36,7 +36,7 @@ std::tuple batch_norm_stats_cuda(const Tensor& self, double epsi Tensor batch_norm_elemt_cuda(const Tensor& self, const Tensor& weight, const Tensor& bias, const Tensor& mean, const Tensor& invstd, double epsilon) { - return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_elemt", [&] { + return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_elemt", [&] { if (cuda::detail::canUse32BitIndexMath(self)) { return batch_norm_elemt_cuda_template(self, weight, bias, mean, invstd, epsilon); } else { @@ -48,7 +48,7 @@ Tensor batch_norm_elemt_cuda(const Tensor& self, const Tensor& weight, const Ten // accepting input(self) here to determine template data types, since running_mean/running_var are optional std::tuple batch_norm_gather_stats_cuda(const Tensor& self, const Tensor& mean, const Tensor& invstd, const Tensor& running_mean, const Tensor& running_var, double momentum, double epsilon, int64_t count) { - return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_update_stats", [&] { + return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_update_stats_cuda", [&] { int world_size = mean.size(1); using accscalar_t = at::acc_type; if (cuda::detail::canUse32BitIndexMath(self)) { @@ -61,7 +61,7 @@ std::tuple batch_norm_gather_stats_cuda(const Tensor& self, cons std::tuple batch_norm_backward_reduce_cuda(const Tensor& self, const Tensor& input, const Tensor& mean, const Tensor& invstd, bool input_g, bool weight_g, bool bias_g) { - return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_backward_reduce", [&] { + return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_backward_reduce", [&] 
{ if (cuda::detail::canUse32BitIndexMath(self)) { return batch_norm_backward_reduce_cuda_template(self, input, mean, invstd, input_g, weight_g, bias_g); } else { @@ -72,7 +72,7 @@ std::tuple batch_norm_backward_reduce_cuda(const Tensor batch_norm_backward_elemt_cuda(const Tensor& self, const Tensor& input, const Tensor& mean, const Tensor& invstd, const Tensor& weight, const Tensor& mean_dy, const Tensor& mean_dy_xmu) { - return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_backward_elemt", [&] { + return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_backward_elemt", [&] { if (cuda::detail::canUse32BitIndexMath(self)) { return batch_norm_backward_elemt_cuda_template(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu); } else { @@ -83,7 +83,7 @@ Tensor batch_norm_backward_elemt_cuda(const Tensor& self, const Tensor& input, c std::tuple batch_norm_update_stats_cuda( const Tensor& self, const Tensor& running_mean, const Tensor& running_var, double momentum) { - return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_backward", [&] { + return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_backward", [&] { auto mean_st = running_mean.dtype(); auto var_st = running_var.dtype(); AT_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types"); diff --git a/aten/src/ATen/native/cuda/RNN.cu b/aten/src/ATen/native/cuda/RNN.cu index a510060..67b3a7b 100644 --- a/aten/src/ATen/native/cuda/RNN.cu +++ b/aten/src/ATen/native/cuda/RNN.cu @@ -501,7 +501,7 @@ std::tuple _thnn_fused_lstm_cell_cuda( auto workspace = at::empty_like(input_gates); auto hy = at::empty_like(cx); auto cy = at::empty_like(cx); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_gates.type(), "_thnn_fused_lstm_cell_cuda", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_gates.scalar_type(), "_thnn_fused_lstm_cell_cuda", [&] { if (canUse32BitIndexMath(workspace)) { // See Note [64-bit index math check elision] lstm_forward_impl(input_gates, hidden_gates, input_bias, hidden_bias, cx, hy, cy, workspace); } else { @@ -540,7 +540,7 @@ std::tuple _thnn_fused_lstm_cell_backwar auto grad_gates = at::empty_like(workspace); auto grad_cx = at::empty_like(cx); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(workspace.type(), "_thnn_fused_lstm_cell_cuda_backward", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(workspace.scalar_type(), "_thnn_fused_lstm_cell_cuda_backward", [&] { if (canUse32BitIndexMath(workspace)) { // See Note [64-bit index math check elision] lstm_backward_impl(grad_hy, grad_cy, cx, cy, workspace, grad_gates, grad_cx); } else { @@ -565,7 +565,7 @@ std::tuple _thnn_fused_gru_cell_cuda( auto workspace = at::empty({hx.size(0), hx.size(1) * GRU_WORKSPACE_MULTIPLIER}, hx.options()); auto hy = at::empty_like(hx); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_gates.type(), "_thnn_fused_gru_cell_cuda", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_gates.scalar_type(), "_thnn_fused_gru_cell_cuda", [&] { if (canUse32BitIndexMath(workspace)) { // See Note [64-bit index math check elision] gru_forward_impl(input_gates, hidden_gates, input_bias, hidden_bias, hx, hy, workspace); } else { @@ -589,7 +589,7 @@ std::tuple _thnn_fused_gru_cell_backward auto grad_input_gates = at::empty({workspace.size(0), hidden_size * 3}, workspace.options()); auto grad_hidden_gates = at::empty({workspace.size(0), hidden_size * 3}, workspace.options()); auto grad_hx = at::empty_like(grad_hy); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_hy.type(), 
"_thnn_fused_gru_cell_cuda_backward", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_hy.scalar_type(), "_thnn_fused_gru_cell_cuda_backward", [&] { if (canUse32BitIndexMath(workspace)) { // See Note [64-bit index math check elision] gru_backward_impl(grad_hy, workspace, grad_input_gates, grad_hidden_gates, grad_hx); } else { diff --git a/aten/src/ATen/native/cuda/RangeFactories.cu b/aten/src/ATen/native/cuda/RangeFactories.cu index 227fd81..5bd7d86 100644 --- a/aten/src/ATen/native/cuda/RangeFactories.cu +++ b/aten/src/ATen/native/cuda/RangeFactories.cu @@ -51,7 +51,7 @@ Tensor& linspace_cuda_out(Tensor& result, Scalar start, Scalar end, int64_t step } else if (steps == 1) { r.fill_(start); } else { - AT_DISPATCH_FLOATING_TYPES(r.type(), "linspace", [&]() { + AT_DISPATCH_FLOATING_TYPES(r.scalar_type(), "linspace_cuda", [&]() { scalar_t scalar_start = start.to(); scalar_t scalar_end = end.to(); scalar_t step = (scalar_end - scalar_start) / static_cast(steps - 1); @@ -81,7 +81,7 @@ Tensor& logspace_cuda_out(Tensor& result, Scalar start, Scalar end, int64_t step } else if (steps == 1) { r.fill_(std::pow(10.0, start.to())); } else { - AT_DISPATCH_FLOATING_TYPES(r.type(), "logspace", [&]() { + AT_DISPATCH_FLOATING_TYPES(r.scalar_type(), "logspace_cuda", [&]() { scalar_t scalar_start = start.to(); scalar_t scalar_end = end.to(); scalar_t step = (scalar_end - scalar_start) / static_cast(steps - 1); @@ -99,7 +99,7 @@ Tensor& logspace_cuda_out(Tensor& result, Scalar start, Scalar end, int64_t step } Tensor& range_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) { - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, result.type(), "range", [&]() { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, result.scalar_type(), "range_cuda", [&]() { using accscalar_t = at::acc_type; auto xstart = start.to(); auto xend = end.to(); @@ -130,7 +130,7 @@ Tensor& range_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) { } Tensor& arange_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) { - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, result.type(), "arange", [&]() { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, result.scalar_type(), "arange_cuda", [&]() { using accscalar_t = at::acc_type; auto xstart = start.to(); auto xend = end.to(); diff --git a/aten/src/ATen/native/cuda/ReduceOpsKernel.cu b/aten/src/ATen/native/cuda/ReduceOpsKernel.cu index 3a47902..715411d 100644 --- a/aten/src/ATen/native/cuda/ReduceOpsKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceOpsKernel.cu @@ -41,7 +41,7 @@ void prod_kernel_impl(TensorIterator& iter) { } static void std_var_kernel_cuda(TensorIterator& iter, bool unbiased, bool take_sqrt) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.type(), "std", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "std", [&]() { std_var_kernel_impl(iter, unbiased, take_sqrt); }); } @@ -77,46 +77,46 @@ void norm_kernel_cuda_impl(TensorIterator& iter, Scalar val) { } static void sum_kernel_cuda(TensorIterator& iter) { - if (iter.type().scalarType() == kHalf) { + if (iter.dtype() == kHalf) { return sum_kernel_impl(iter); - } else if (iter.type(1).scalarType() == kHalf && iter.type().scalarType() == kFloat) { + } else if (iter.dtype(1) == kHalf && iter.dtype() == kFloat) { // type promotion that does cast and reduction in a single kernel return sum_kernel_impl(iter); } - AT_DISPATCH_ALL_TYPES(iter.type(), "sum", [&]() { + AT_DISPATCH_ALL_TYPES(iter.dtype(), "sum_cuda", [&]() { sum_kernel_impl(iter); }); } static void prod_kernel_cuda(TensorIterator& 
iter) { - if (iter.type().scalarType() == kHalf) { + if (iter.dtype() == kHalf) { return prod_kernel_impl(iter); } - AT_DISPATCH_ALL_TYPES(iter.type(), "prod", [&]() { + AT_DISPATCH_ALL_TYPES(iter.dtype(), "prod_cuda", [&]() { prod_kernel_impl(iter); }); } static void mean_kernel_cuda(TensorIterator& iter) { - if (iter.type().scalarType() == kHalf) { + if (iter.dtype() == kHalf) { return mean_kernel_impl(iter); - } else if (iter.type(1).scalarType() == kHalf && iter.type().scalarType() == kFloat) { + } else if (iter.dtype(1) == kHalf && iter.dtype() == kFloat) { // type promotion that does cast and reduction in a single kernel return mean_kernel_impl(iter); } - AT_DISPATCH_ALL_TYPES(iter.type(), "mean", [&]() { + AT_DISPATCH_ALL_TYPES(iter.dtype(), "mean_cuda", [&]() { mean_kernel_impl(iter); }); } static void norm_kernel_cuda(TensorIterator& iter, Scalar p) { - if (iter.type().scalarType() == kHalf) { + if (iter.dtype() == kHalf) { return norm_kernel_cuda_impl(iter, p); - } else if (iter.type(1).scalarType() == kHalf && iter.type().scalarType() == kFloat) { + } else if (iter.dtype(1) == kHalf && iter.dtype() == kFloat) { // type promotion that does cast and reduction in a single kernel return norm_kernel_cuda_impl(iter, p); } - AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&]() { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "norm_cuda", [&]() { norm_kernel_cuda_impl(iter, p); }); } @@ -152,13 +152,13 @@ void min_values_kernel_cuda_impl(TensorIterator& iter) { } void max_values_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES(iter.type(), "max_values", [&]() { + AT_DISPATCH_ALL_TYPES(iter.dtype(), "max_values_cuda", [&]() { max_values_kernel_cuda_impl(iter); }); } void min_values_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES(iter.type(), "min_values", [&]() { + AT_DISPATCH_ALL_TYPES(iter.dtype(), "min_values_cuda", [&]() { min_values_kernel_cuda_impl(iter); }); } diff --git a/aten/src/ATen/native/cuda/ReflectionPad.cu b/aten/src/ATen/native/cuda/ReflectionPad.cu index 5e6e6bf..0a1cacf 100644 --- a/aten/src/ATen/native/cuda/ReflectionPad.cu +++ b/aten/src/ATen/native/cuda/ReflectionPad.cu @@ -191,7 +191,7 @@ void reflection_pad1d_out_template( Tensor input = input_.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.type(), "reflection_pad1d_out_template", [&] { + input.scalar_type(), "reflection_pad1d_out_template", [&] { reflection_pad1d_out_kernel<<< grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>( input.data(), output.data(), @@ -239,7 +239,7 @@ void reflection_pad1d_backward_out_template( dim3 grid_size((int) ::ceil(output_w / 256.0), nplane, nbatch); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad_input.type(), "reflection_pad1d_backward_out_template", [&] { + grad_input.scalar_type(), "reflection_pad1d_backward_out_template", [&] { reflection_pad1d_backward_out_kernel<<< grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>( grad_input.data(), grad_output.data(), @@ -311,7 +311,7 @@ void reflection_pad2d_out_template( (int) std::ceil(output_plane_size/256.0), nplane, nbatch); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.type(), "reflection_pad2d_out_template", [&] { + input.scalar_type(), "reflection_pad2d_out_template", [&] { reflection_pad2d_out_kernel<<< grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>( input.data(), output.data(), @@ -368,7 +368,7 @@ void reflection_pad2d_backward_out_template( (int) std::ceil(output_plane_size/256.0), nplane, nbatch); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.type(), 
"reflection_pad2d_backward_out_template", [&] { + input.scalar_type(), "reflection_pad2d_backward_out_template", [&] { reflection_pad2d_backward_out_kernel<<< grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>( grad_input.data(), grad_output.data(), diff --git a/aten/src/ATen/native/cuda/ReplicationPadding.cu b/aten/src/ATen/native/cuda/ReplicationPadding.cu index a9790df..867ebf2 100644 --- a/aten/src/ATen/native/cuda/ReplicationPadding.cu +++ b/aten/src/ATen/native/cuda/ReplicationPadding.cu @@ -235,7 +235,7 @@ void replication_pad1d_out_cuda_template( AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.type(), "replication_pad1d", [&] { + input.scalar_type(), "replication_pad1d_cuda", [&] { if (numInputDims == 2) { @@ -306,7 +306,7 @@ void replication_pad1d_backward_out_cuda_template( gradInput.zero_(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.type(), "replication_pad1d_backward", [&] { + input.scalar_type(), "replication_pad1d_backward_cuda", [&] { auto gradInput_ = gradInput; auto gradOutput_ = gradOutput; @@ -372,7 +372,7 @@ void replication_pad2d_out_cuda_template( " Calculated output H: ", outputH, " W: ", outputW); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.type(), "replication_pad2d", [&] { + input.scalar_type(), "replication_pad2d_cuda", [&] { if (numInputDims == 3) { @@ -403,7 +403,7 @@ void replication_pad2d_out_cuda_template( dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize); replication_pad_forward_kernel2d <<>>(devInput, devOutput, + at::cuda::getCurrentCUDAStream()>>>(devInput, devOutput, padT, padB, padL, padR); } } @@ -427,7 +427,7 @@ void replication_pad2d_backward_out_cuda_template( int padL = paddingSize[0]; int padR = paddingSize[1]; int padT = paddingSize[2]; - int padB = paddingSize[3]; + int padB = paddingSize[3]; int planeDim = 0; int dimh = 1; int dimw = 2; @@ -454,7 +454,7 @@ void replication_pad2d_backward_out_cuda_template( gradInput.zero_(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.type(), "replication_pad2d_backward", [&] { + input.scalar_type(), "replication_pad2d_backward_cuda", [&] { auto gradInput_ = gradInput; auto gradOutput_ = gradOutput; @@ -483,7 +483,7 @@ static inline void shapeCheck3d( int pleft, int pright, int ptop, int pbottom, int pfront, int pback) { - AT_CHECK(at::cuda::detail::canUse32BitIndexMath(input), + AT_CHECK(at::cuda::detail::canUse32BitIndexMath(input), "input tensor must fit into 32-bit index math"); int numInputDims = input.dim(); @@ -521,7 +521,7 @@ static inline void shapeAndGradOutputCheck3d( int pleft, int pright, int ptop, int pbottom, int pfront, int pback) { - AT_CHECK(at::cuda::detail::canUse32BitIndexMath(input), + AT_CHECK(at::cuda::detail::canUse32BitIndexMath(input), "input tensor must fit into 32-bit index math"); int numInputDims = input.dim(); @@ -579,7 +579,7 @@ void replication_pad3d_out_cuda_template( int ptop = paddingSize[2]; int pbottom = paddingSize[3]; int pfront = paddingSize[4]; - int pback = paddingSize[5]; + int pback = paddingSize[5]; shapeCheck3d(input, pleft, pright, ptop, pbottom, pfront, pback); @@ -608,7 +608,7 @@ void replication_pad3d_out_cuda_template( int outputW = inputW + pleft + pright; AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.type(), "replication_pad3d", [&] { + input.scalar_type(), "replication_pad3d_cuda", [&] { if (numInputDims == 4) { output.resize_({numPlanes, outputD, outputH, outputW}); @@ -660,7 +660,7 @@ void replication_pad3d_backward_out_cuda_template( int ptop = paddingSize[2]; int pbottom = paddingSize[3]; int pfront = paddingSize[4]; - int 
pback = paddingSize[5]; + int pback = paddingSize[5]; shapeAndGradOutputCheck3d(input, gradOutput, pleft, pright, ptop, pbottom, pfront, pback); @@ -681,7 +681,7 @@ void replication_pad3d_backward_out_cuda_template( gradInput.zero_(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.type(), "replication_pad3d_backward", [&] { + input.scalar_type(), "replication_pad3d_backward_cuda", [&] { auto gradInput_ = gradInput; auto gradOutput_ = gradOutput; diff --git a/aten/src/ATen/native/cuda/SoftMax.cu b/aten/src/ATen/native/cuda/SoftMax.cu index 8808272..ef3031e 100644 --- a/aten/src/ATen/native/cuda/SoftMax.cu +++ b/aten/src/ATen/native/cuda/SoftMax.cu @@ -499,7 +499,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t const int ILP = 2; dim3 grid(outer_size); dim3 block = SoftMax_getBlockSize(ILP, dim_size); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "host_softmax", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "host_softmax", [&] { using accscalar_t = acc_type; if (!half_to_float) { cunn_SoftMaxForward @@ -519,7 +519,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t } else { uint32_t smem_size; dim3 grid, block; - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "host_softmax", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "host_softmax", [&] { using accscalar_t = acc_type; if (!half_to_float) { SpatialSoftMax_getLaunchSizes( @@ -573,7 +573,7 @@ Tensor host_softmax_backward(const Tensor &grad_, const Tensor &output_, int64_t const int ILP = 2; dim3 grid(outer_size); dim3 block = SoftMax_getBlockSize(ILP, dim_size); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(gI.type(), "host_softmax_backward", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(gI.scalar_type(), "host_softmax_backward", [&] { using accscalar_t = acc_type; if (!half_to_float) { cunn_SoftMaxBackward @@ -590,7 +590,7 @@ Tensor host_softmax_backward(const Tensor &grad_, const Tensor &output_, int64_t } else { uint32_t smem_size; dim3 grid, block; - AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "host_softmax_backward", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.scalar_type(), "host_softmax_backward", [&] { using accscalar_t = acc_type; if (!half_to_float) { SpatialSoftMax_getLaunchSizes( diff --git a/aten/src/ATen/native/cuda/SortingKthValue.cu b/aten/src/ATen/native/cuda/SortingKthValue.cu index 4018d1d..50d0b90 100644 --- a/aten/src/ATen/native/cuda/SortingKthValue.cu +++ b/aten/src/ATen/native/cuda/SortingKthValue.cu @@ -233,14 +233,14 @@ std::tuple kthvalue_out_cuda( int64_t k, int64_t dim, bool keepdim) { - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), "kthvalue", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.scalar_type(), "kthvalue_cuda", [&] { kthvalue_cuda_template(values, indices, self, k, dim, keepdim); }); return std::forward_as_tuple(values, indices); } Tensor median_cuda(const Tensor& self) { - return AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), "median", [&] { + return AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.scalar_type(), "median", [&] { return median_cuda_template(self); }); } diff --git a/aten/src/ATen/native/cuda/SpectralOps.cu b/aten/src/ATen/native/cuda/SpectralOps.cu index 6a79f66..19c28ff 100644 --- a/aten/src/ATen/native/cuda/SpectralOps.cu +++ b/aten/src/ATen/native/cuda/SpectralOps.cu @@ -114,7 +114,7 @@ static void _fft_fill_with_conjugate_symmetry_(Tensor& input, cudaStream_t stream = at::cuda::getCurrentCUDAStream(); auto allocator = 
THCThrustAllocator(globalContext().lazyInitCUDA()); auto policy = thrust::cuda::par(allocator).on(stream); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "_fft_fill_with_conjugate_symmetry_", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "_fft_fill_with_conjugate_symmetry_", [&] { typedef thrust::device_ptr device_ptr; typedef thrust::counting_iterator counter; typedef thrust::transform_iterator dst_idx_iterator; diff --git a/aten/src/ATen/native/cuda/SummaryOps.cu b/aten/src/ATen/native/cuda/SummaryOps.cu index 3903d9b..d2c7be5 100644 --- a/aten/src/ATen/native/cuda/SummaryOps.cu +++ b/aten/src/ATen/native/cuda/SummaryOps.cu @@ -330,7 +330,7 @@ Tensor _bincount_cuda( const Tensor& self, const Tensor& weights, int64_t minlength) { - return AT_DISPATCH_INTEGRAL_TYPES(self.type(), "bincount", [&] { + return AT_DISPATCH_INTEGRAL_TYPES(self.scalar_type(), "bincount_cuda", [&] { const auto scalar = weights.scalar_type(); if (scalar == ScalarType::Undefined || scalar == ScalarType::Float) return _bincount_cuda_template(self, weights, minlength); @@ -347,7 +347,7 @@ Tensor _histc_cuda( if (self.scalar_type() == ScalarType::Half) { AT_ERROR("HalfTensor is not supported"); } - return AT_DISPATCH_ALL_TYPES(self.type(), "histc", [&] { + return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "histc", [&] { return _histc_cuda_template(self, nbins, min.to(), max.to()); }); } diff --git a/aten/src/ATen/native/cuda/TensorCompare.cu b/aten/src/ATen/native/cuda/TensorCompare.cu index f7d9b68..d4ccc70 100644 --- a/aten/src/ATen/native/cuda/TensorCompare.cu +++ b/aten/src/ATen/native/cuda/TensorCompare.cu @@ -33,7 +33,7 @@ Tensor _s_where_cuda( const Tensor& self, const Tensor& other) { Tensor ret = at::empty(self.sizes(), self.options()); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ret.type(), "where", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ret.scalar_type(), "where_cuda", [&] { where_cuda(ret, condition, self, other); }); return ret; diff --git a/aten/src/ATen/native/cuda/TensorFactories.cu b/aten/src/ATen/native/cuda/TensorFactories.cu index f0c32cc..c9bb377 100644 --- a/aten/src/ATen/native/cuda/TensorFactories.cu +++ b/aten/src/ATen/native/cuda/TensorFactories.cu @@ -90,7 +90,7 @@ Tensor& randperm_out_cuda(Tensor& result, int64_t n, Generator* generator) { } else { // Generate random values for the keys array AT_DISPATCH_ALL_TYPES( - result.type(), "randperm_out_cuda", [&] { + result.scalar_type(), "randperm_out_cuda", [&] { auto keys = at::empty(result.sizes(), result.options()).random_(generator); auto result_data = thrust::device_ptr(result.data()); @@ -322,7 +322,7 @@ Tensor tril_indices_cuda( cuda::getApplyGrid(tril_size, dim_grid, tensor.get_device()), "unable to get dim grid"); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, tensor.type(), "tril_indices_cuda", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, tensor.scalar_type(), "tril_indices_cuda", [&] { tril_indices_kernel<<< dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>( tensor.data(), @@ -398,7 +398,7 @@ Tensor triu_indices_cuda( cuda::getApplyGrid(triu_size, dim_grid, tensor.get_device()), "unable to get dim grid"); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, tensor.type(), "triu_indices_cuda", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, tensor.scalar_type(), "triu_indices_cuda", [&] { triu_indices_kernel<<< dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>( tensor.data(), diff --git a/aten/src/ATen/native/cuda/TensorTransformations.cu 
b/aten/src/ATen/native/cuda/TensorTransformations.cu index 5fccc10..00f50f1 100644 --- a/aten/src/ATen/native/cuda/TensorTransformations.cu +++ b/aten/src/ATen/native/cuda/TensorTransformations.cu @@ -87,7 +87,7 @@ Tensor flip_cuda(const Tensor& self, IntArrayRef dims) { // use kernel_pointwise_flip_apply2 only when to-flip dim is the 1st or last dim, where collapseDims can reduce the amount of work if (flip_dims_size == 1 && in_tensor.is_contiguous() && (flip_dims[0] == 0 || flip_dims[0] == total_dims - 1)) { - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, in_tensor.type(), "flip_cuda", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, in_tensor.scalar_type(), "flip_cuda", [&] { auto in_tensor_info = cuda::detail::getTensorInfo(in_tensor); auto out_tensor_info = cuda::detail::getTensorInfo(out_tensor); int flip_dim = in_tensor_info.collapseDims(flip_dims[0]); @@ -119,7 +119,7 @@ Tensor flip_cuda(const Tensor& self, IntArrayRef dims) { } } - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, in_tensor.type(), "flip_cuda", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, in_tensor.scalar_type(), "flip_cuda", [&] { flip_cuda_kernel<<>>( in_tensor.data(), out_tensor.data(), N, flip_dims_t.toType(CUDA(kLong)).data(), flip_dims_size, strides_t.toType(CUDA(kLong)).data(), stride_contiguous.toType(CUDA(kLong)).data(), shape_t.toType(CUDA(kLong)).data(), total_dims); @@ -177,7 +177,7 @@ Tensor roll_cuda(const Tensor& self, IntArrayRef shifts, IntArrayRef dims) { auto total_dims = in_tensor.dim(); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, in_tensor.type(), "roll_cuda", [&] { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, in_tensor.scalar_type(), "roll_cuda", [&] { roll_cuda_kernel<<>>( in_tensor.data(), out_tensor.data(), N, dim, start, diff --git a/aten/src/ATen/native/cuda/Unique.cu b/aten/src/ATen/native/cuda/Unique.cu index 828fb48..0ba6812 100644 --- a/aten/src/ATen/native/cuda/Unique.cu +++ b/aten/src/ATen/native/cuda/Unique.cu @@ -146,7 +146,7 @@ template std::tuple _unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) { - return AT_DISPATCH_ALL_TYPES(self.type(), "unique", [&] { + return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_cuda", [&] { // The current CUDA implementation of unique always sort due to the // lack of hashtable implementation in thrust return _unique_cuda_template(self, return_inverse); @@ -155,7 +155,7 @@ _unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) { std::tuple _unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) { - return AT_DISPATCH_ALL_TYPES(self.type(), "unique_dim", [&] { + return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] { return _unique_dim_cuda_template(self, dim, return_inverse); }); } diff --git a/aten/src/ATen/native/cuda/WeightNorm.cu b/aten/src/ATen/native/cuda/WeightNorm.cu index f8e9c14..76f4272 100644 --- a/aten/src/ATen/native/cuda/WeightNorm.cu +++ b/aten/src/ATen/native/cuda/WeightNorm.cu @@ -7,7 +7,7 @@ #include #include -namespace at { +namespace at { namespace native { namespace { @@ -15,15 +15,15 @@ namespace { // Currently, kernels are non-persistent. // Dialing up the block size to, say 1024, can improve performance by // increase the amount of cache available per block, which can improve cache hit rate. -// However, this is less efficient for short rows. 256 is pretty versatile. +// However, this is less efficient for short rows. 256 is pretty versatile. 
// May be worth implementing heuristics later. #define BLOCK 256 // Block size for weight_norm_*_last_dim_kernel. -// This is tricker than the first_dim case because we must make blocks +// This is tricker than the first_dim case because we must make blocks // at least 16 fast elements wide to ensure fully-coalesced half-precision accesses. -// Since output-element parallelism is along the fast dimension, this reduces the number of -// blocks we can launch by 16X. +// Since output-element parallelism is along the fast dimension, this reduces the number of +// blocks we can launch by 16X. #define TILE_W 16 // Somewhat versatile strategy: max out intra-block parallelism by extending // blocks across the slow dimension up to the hardware-max block size of 1024. @@ -31,11 +31,11 @@ namespace { template __device__ __forceinline__ void reduce_block_into_lanes - (T *x, - T val, + (T *x, + T val, int lanes, // lanes is intended to be <= 32. - ReduceOp reduceOp) -{ + ReduceOp reduceOp) +{ int tid = threadIdx.x + threadIdx.y*blockDim.x; int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32. @@ -44,16 +44,16 @@ __device__ __forceinline__ void reduce_block_into_lanes x[tid] = val; __syncthreads(); } - + #pragma unroll - for(int i = (blockSize >> 1); i >= 64; i >>= 1) + for(int i = (blockSize >> 1); i >= 64; i >>= 1) { if(tid < i) x[tid] = reduceOp(x[tid], x[tid+i]); __syncthreads(); } - if(tid < 32) + if(tid < 32) { T final; if(blockSize >= 64) @@ -66,7 +66,7 @@ __device__ __forceinline__ void reduce_block_into_lanes for(int i = 16; i >= lanes; i >>= 1) final = reduceOp(final, WARP_SHFL_DOWN(final, i)); - if(tid < lanes) + if(tid < lanes) x[tid] = final; // EpilogueOp } @@ -75,14 +75,14 @@ __device__ __forceinline__ void reduce_block_into_lanes } template - __global__ void weight_norm_fwd_first_dim_kernel (scalar_t* __restrict__ w, accscalar_t* __restrict__ norms, const scalar_t* __restrict__ v, const scalar_t* __restrict__ g, - const int rowSize) + const int rowSize) { // We are norming each slowest-dim row of the tensor separately. // For now, assign one block to each row. 
@@ -98,11 +98,11 @@ __global__ void weight_norm_fwd_first_dim_kernel // extern __shared__ accscalar_t s[]; // error: declaration is incompatible with previous "s" extern __shared__ char buf[]; accscalar_t* s = (accscalar_t*)buf; - + accscalar_t thread_sum = 0.f; - for(int i = tid; i < rowSize; i += stride ) + for(int i = tid; i < rowSize; i += stride ) { - accscalar_t val_f = scalar_cast(v[i+rowStart]); + accscalar_t val_f = scalar_cast(v[i+rowStart]); thread_sum += val_f*val_f; // AccumOp, could do Kahan here } @@ -110,7 +110,7 @@ __global__ void weight_norm_fwd_first_dim_kernel accscalar_t result = s[0]; result = sqrtf(result); - + if(tid == 0) norms[row] = result; @@ -120,7 +120,7 @@ __global__ void weight_norm_fwd_first_dim_kernel accscalar_t rnorm = 1.f/result; // for consistency with backward kernel // Write data to output - for(int i = tid; i < rowSize; i += stride ) + for(int i = tid; i < rowSize; i += stride ) { accscalar_t val_f = scalar_cast(v[i+rowStart]); w[i+rowStart] = scalar_cast(g_this_row*val_f*rnorm); @@ -128,7 +128,7 @@ __global__ void weight_norm_fwd_first_dim_kernel } template - __global__ void weight_norm_fwd_last_dim_kernel ( @@ -154,13 +154,13 @@ __global__ void weight_norm_fwd_last_dim_kernel if(fast_dim_location < fast_dim_size) while(slower_dims_location < slower_dims_size) { - accscalar_t val_f = scalar_cast(v[currentIdx]); + accscalar_t val_f = scalar_cast(v[currentIdx]); thread_sum += val_f*val_f; // AccumOp, could do Kahan here currentIdx += blockDim.y*fast_dim_size; - slower_dims_location += blockDim.y; + slower_dims_location += blockDim.y; } - reduce_block_into_lanes(s, thread_sum, blockDim.x, ReduceAdd()); + reduce_block_into_lanes(s, thread_sum, blockDim.x, ReduceAdd()); // Better to pass an EpilogueOp to reduce_block_into_lanes? 
if(threadIdx.y == 0) @@ -170,26 +170,26 @@ __global__ void weight_norm_fwd_last_dim_kernel norms[fast_dim_location] = norm_this_col; rnorms_this_block[threadIdx.x] = 1.f/norm_this_col; } - - __syncthreads(); - accscalar_t g_this_col = scalar_cast(g[fast_dim_location]); - accscalar_t rnorm = rnorms_this_block[threadIdx.x]; + __syncthreads(); + + accscalar_t g_this_col = scalar_cast(g[fast_dim_location]); + accscalar_t rnorm = rnorms_this_block[threadIdx.x]; slower_dims_location = threadIdx.y; currentIdx = fast_dim_location + fast_dim_size*slower_dims_location; if(fast_dim_location < fast_dim_size) while(slower_dims_location < slower_dims_size) { - accscalar_t val_f = scalar_cast(v[currentIdx]); + accscalar_t val_f = scalar_cast(v[currentIdx]); w[currentIdx] = scalar_cast(g_this_col*val_f*rnorm); currentIdx += blockDim.y*fast_dim_size; - slower_dims_location += blockDim.y; - } + slower_dims_location += blockDim.y; + } } template - __global__ void weight_norm_bwd_first_dim_kernel (scalar_t* __restrict__ grad_v, @@ -213,12 +213,12 @@ __global__ void weight_norm_bwd_first_dim_kernel // extern __shared__ accscalar_t s[]; // error: declaration is incompatible with previous "s" extern __shared__ char buf[]; accscalar_t* s = (accscalar_t*)buf; - + accscalar_t thread_sum = 0.f; - for(int i = tid; i < rowSize; i += stride ) + for(int i = tid; i < rowSize; i += stride ) { - accscalar_t grad_wi = scalar_cast(grad_w[i+rowStart]); - accscalar_t saved_vi = scalar_cast(saved_v[i+rowStart]); + accscalar_t grad_wi = scalar_cast(grad_w[i+rowStart]); + accscalar_t saved_vi = scalar_cast(saved_v[i+rowStart]); thread_sum += grad_wi*saved_vi; // AccumOp, could do Kahan here } @@ -228,7 +228,7 @@ __global__ void weight_norm_bwd_first_dim_kernel // Could choose to save reciprocal of norm instead I suppose, but norms is probably // more handy to keep around. // Broadcast load; could use shared memory instead. - accscalar_t rnorm = 1.f/saved_norms[row]; + accscalar_t rnorm = 1.f/saved_norms[row]; accscalar_t rnorm3 = rnorm*rnorm*rnorm; // Write g gradients. @@ -237,20 +237,20 @@ __global__ void weight_norm_bwd_first_dim_kernel // Broadcast load, could use shared memory instead. accscalar_t g_this_row = scalar_cast(saved_g[row]); - - // Write v gradients. We are reusing values that were loaded earlier, so there + + // Write v gradients. We are reusing values that were loaded earlier, so there // is an optimization opportunity here (store values persistently). 
- for(int j = tid; j < rowSize; j += stride ) + for(int j = tid; j < rowSize; j += stride ) { - accscalar_t grad_wj = scalar_cast(grad_w[j+rowStart]); - accscalar_t saved_vj = scalar_cast(saved_v[j+rowStart]); + accscalar_t grad_wj = scalar_cast(grad_w[j+rowStart]); + accscalar_t saved_vj = scalar_cast(saved_v[j+rowStart]); accscalar_t grad_vj = g_this_row*(rnorm*grad_wj - rnorm3*saved_vj*result); grad_v[j+rowStart] = scalar_cast(grad_vj); } } -template - __global__ void weight_norm_bwd_last_dim_kernel (scalar_t* __restrict__ grad_v, @@ -274,18 +274,18 @@ __global__ void weight_norm_bwd_last_dim_kernel if(fast_dim_location < fast_dim_size) while(slower_dims_location < slower_dims_size) { - accscalar_t grad_wi = scalar_cast(grad_w[currentIdx]); - accscalar_t saved_vi = scalar_cast(saved_v[currentIdx]); + accscalar_t grad_wi = scalar_cast(grad_w[currentIdx]); + accscalar_t saved_vi = scalar_cast(saved_v[currentIdx]); thread_sum += grad_wi*saved_vi; // AccumOp, could do Kahan here currentIdx += blockDim.y*fast_dim_size; - slower_dims_location += blockDim.y; + slower_dims_location += blockDim.y; } - reduce_block_into_lanes(s, thread_sum, blockDim.x, ReduceAdd()); + reduce_block_into_lanes(s, thread_sum, blockDim.x, ReduceAdd()); accscalar_t result = s[threadIdx.x]; // Broadcast load; could use shared memory instead. - accscalar_t rnorm = 1.f/saved_norms[fast_dim_location]; + accscalar_t rnorm = 1.f/saved_norms[fast_dim_location]; accscalar_t rnorm3 = rnorm*rnorm*rnorm; // Write g gradients. @@ -301,13 +301,13 @@ __global__ void weight_norm_bwd_last_dim_kernel if(fast_dim_location < fast_dim_size) while(slower_dims_location < slower_dims_size) { - accscalar_t grad_wj = scalar_cast(grad_w[currentIdx]); - accscalar_t saved_vj = scalar_cast(saved_v[currentIdx]); + accscalar_t grad_wj = scalar_cast(grad_w[currentIdx]); + accscalar_t saved_vj = scalar_cast(saved_v[currentIdx]); accscalar_t grad_vj = g_this_col*(rnorm*grad_wj - rnorm3*saved_vj*result); grad_v[currentIdx] = scalar_cast(grad_vj); currentIdx += blockDim.y*fast_dim_size; - slower_dims_location += blockDim.y; - } + slower_dims_location += blockDim.y; + } } } // anonymous namespace @@ -315,7 +315,7 @@ __global__ void weight_norm_bwd_last_dim_kernel std::tuple weight_norm_cuda (const Tensor & v, const Tensor & g, - int64_t dim) + int64_t dim) { auto w = at::empty_like(v); @@ -323,17 +323,17 @@ std::tuple weight_norm_cuda // sends the unpacked g.data() as the argument. In other words, we expect "g" is a bare Tensor here. // norms is only needed to stash for backward. - // g.scalar_type() may be at::ScalarType::Double, Float, or Half. + // g.scalar_type() may be at::ScalarType::Double, Float, or Half. // If Half, stash norms as float. at::ScalarType AccType = g.scalar_type() == at::ScalarType::Half ? at::ScalarType::Float : g.scalar_type(); - // Will this create norms on the same device as g, regardless of what the thread's default + // Will this create norms on the same device as g, regardless of what the thread's default // current device is? I believe so, because Type::* functions are DeviceGuard()ed. 
auto norms = at::empty_strided(g.sizes(), g.strides(), g.options().dtype(AccType)); const int ndims = v.dim(); - if(dim == 0) + if(dim == 0) { // Find logical size of each flattened slowest-dim row int rowSize = 1; @@ -343,21 +343,21 @@ std::tuple weight_norm_cuda cudaStream_t stream = at::cuda::getCurrentCUDAStream(); AT_DISPATCH_FLOATING_TYPES_AND_HALF - (v.type(), - "weight_norm_fwd_first_dim_kernel", + (v.scalar_type(), + "weight_norm_fwd_first_dim_kernel", [&] { using accscalar_t = acc_type; weight_norm_fwd_first_dim_kernel - <<>> - (w.data(), + (w.data(), norms.data(), - v.data(), - g.data(), + v.data(), + g.data(), rowSize); }); } @@ -369,16 +369,16 @@ std::tuple weight_norm_cuda slower_dims_size *= v.size(i); int fast_dim_size = v.size(ndims-1); - + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); AT_DISPATCH_FLOATING_TYPES_AND_HALF - (v.type(), - "weight_norm_fwd_last_dim_kernel", + (v.scalar_type(), + "weight_norm_fwd_last_dim_kernel", [&] { using accscalar_t = acc_type; - + weight_norm_fwd_last_dim_kernel <<<(fast_dim_size+TILE_W-1)/TILE_W, dim3(TILE_W,TILE_H), @@ -395,7 +395,7 @@ std::tuple weight_norm_cuda // The kernel execution is asynchronous, so this will only catch errors on the kernel launch, // not the kernel's execution. Errors in kernel execution aren't guaranteed to be caught - // until a later error check on a synchronizing CUDA call. Unfortunately, without manually + // until a later error check on a synchronizing CUDA call. Unfortunately, without manually // synchronizing here, this is the best we can do. THCudaCheck(cudaGetLastError()); @@ -403,9 +403,9 @@ std::tuple weight_norm_cuda } std::tuple weight_norm_cuda_backward - (const Tensor & grad_w, - const Tensor & saved_v, - const Tensor & saved_g, + (const Tensor & grad_w, + const Tensor & saved_v, + const Tensor & saved_g, const Tensor & saved_norms, int64_t dim) { @@ -421,7 +421,7 @@ std::tuple weight_norm_cuda_backward const int ndims = saved_v.dim(); - if(dim == 0) + if(dim == 0) { // Find logical size of each flattened slowest-dim row int rowSize = 1; @@ -431,15 +431,15 @@ std::tuple weight_norm_cuda_backward cudaStream_t stream = at::cuda::getCurrentCUDAStream(); AT_DISPATCH_FLOATING_TYPES_AND_HALF - (saved_v.type(), - "weight_norm_bwd_first_dim_kernel", + (saved_v.scalar_type(), + "weight_norm_bwd_first_dim_kernel", [&] { using accscalar_t = acc_type; weight_norm_bwd_first_dim_kernel - <<>> (grad_v.data(), @@ -463,15 +463,15 @@ std::tuple weight_norm_cuda_backward cudaStream_t stream = at::cuda::getCurrentCUDAStream(); AT_DISPATCH_FLOATING_TYPES_AND_HALF - (saved_v.type(), - "weight_norm_bwd_last_dim_kernel", + (saved_v.scalar_type(), + "weight_norm_bwd_last_dim_kernel", [&] { using accscalar_t = acc_type; weight_norm_bwd_last_dim_kernel <<<(fast_dim_size+TILE_W-1)/TILE_W, - dim3(TILE_W,TILE_H), + dim3(TILE_W,TILE_H), (TILE_W*TILE_H + TILE_W)*sizeof(accscalar_t), stream>>> (grad_v.data(), @@ -487,7 +487,7 @@ std::tuple weight_norm_cuda_backward // The kernel execution is asynchronous, so this will only catch errors on the kernel launch, // not the kernel's execution. Errors in kernel execution aren't guaranteed to be caught - // until a later error check on a synchronizing CUDA call. Unfortunately, without manually + // until a later error check on a synchronizing CUDA call. Unfortunately, without manually // synchronizing here, this is the best we can do. 
THCudaCheck(cudaGetLastError()); diff --git a/aten/src/ATen/native/mkl/LinearAlgebra.cpp b/aten/src/ATen/native/mkl/LinearAlgebra.cpp index 97a641c..809bd82 100644 --- a/aten/src/ATen/native/mkl/LinearAlgebra.cpp +++ b/aten/src/ATen/native/mkl/LinearAlgebra.cpp @@ -83,7 +83,7 @@ static inline void baddbmm_mkl_template(const Tensor& res, const Tensor& mat1, c Tensor& _baddbmm_mkl_(Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { // checks are done in native/LinearAlgebra.cpp - AT_DISPATCH_FLOATING_TYPES(self.type(), "baddbmm__mkl", [&] { + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "baddbmm__mkl", [&] { baddbmm_mkl_template(self, batch1, batch2, beta, alpha); }); diff --git a/aten/src/ATen/native/mkl/SpectralOps.cpp b/aten/src/ATen/native/mkl/SpectralOps.cpp index 31f8697..6ffe4cb 100644 --- a/aten/src/ATen/native/mkl/SpectralOps.cpp +++ b/aten/src/ATen/native/mkl/SpectralOps.cpp @@ -145,7 +145,7 @@ static inline void _fft_fill_with_conjugate_symmetry_(Tensor& input, { int tid = omp_get_thread_num(); int64_t start = tid * num_slices_per_thread; - AT_DISPATCH_FLOATING_TYPES(input.type(), "_fft_fill_with_conjugate_symmetry", [&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "_fft_fill_with_conjugate_symmetry", [&] { _fft_fill_with_conjugate_symmetry_slice(input, signal_ndim, size_last_dim, last_dim_start_slice, start, std::min(num_slices_per_thread, num - start)); }); @@ -153,7 +153,7 @@ static inline void _fft_fill_with_conjugate_symmetry_(Tensor& input, return; } #endif - AT_DISPATCH_FLOATING_TYPES(input.type(), "_fft_fill_with_conjugate_symmetry", [&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "_fft_fill_with_conjugate_symmetry", [&] { _fft_fill_with_conjugate_symmetry_slice(input, signal_ndim, size_last_dim, last_dim_start_slice, 0, num); }); @@ -291,4 +291,3 @@ Tensor _fft_mkl(const Tensor& self, int64_t signal_ndim, }} // namespace at::native #endif - diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index f89dce3..d3278c6 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -287,7 +287,7 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){ // TODO: it seems like sparse_dim == 0 could be supported even if self.dim() > 0, // but this would take some work and doesn't seem particularly useful. 
AT_CHECK(sparse_dim > 0 || self.dim() == 0, "sparse_dim must be >0 if dimensionality > 0"); - AT_CHECK(sparse_dim <= dims, + AT_CHECK(sparse_dim <= dims, "sparse_dim must be less than or equal to self.dim()"); at::TensorOptions sparse_options = self.options().layout(kSparse); std::vector sizes = self.sizes().vec(); @@ -376,7 +376,7 @@ SparseTensor coalesce_sparse_cpu(const SparseTensor& self) { int64_t i = -1; AT_DISPATCH_ALL_TYPES( - values.type(), "coalesce", [&] { + values.scalar_type(), "coalesce", [&] { int64_t prev = -1; int64_t blockSize = values.stride(0); scalar_t* values_ptr = values.data(); @@ -483,7 +483,7 @@ SparseTensor& sparse_mask_out_cpu(SparseTensor& r, const Tensor& t, const Sparse // TODO: Re-audit this; it used to be an indexSelect directly into r_values at::index_select_out(r_values, t_view, 0, indices); } else { - AT_DISPATCH_ALL_TYPES(r_values.type(), "sparse_mask", [&] { + AT_DISPATCH_ALL_TYPES(r_values.scalar_type(), "sparse_mask", [&] { sparse_mask_out_cpu_kernel( r_values, t, diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp index 7548961..6e31b23 100644 --- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp @@ -226,7 +226,7 @@ SparseTensor& add_out_sparse_cpu(SparseTensor& r, const SparseTensor& t, const S auto src_indices_accessor = src_indices.accessor(); AT_DISPATCH_ALL_TYPES( - t_values.type(), "cadd_sparse", [&] { + t_values.scalar_type(), "cadd_sparse", [&] { scalar_t* t_values_ptr = t_values.data(); scalar_t* s_values_ptr = s_values.data(); scalar_t* r_values_ptr = r_values.data(); @@ -347,7 +347,7 @@ Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, SparseTensorRef } } else { AT_DISPATCH_ALL_TYPES( - values.type(), "add_dense_sparse", [&] { + values.scalar_type(), "add_dense_sparse", [&] { add_dense_sparse_worker_cpu(r, value, sparse, indices, values); }); } @@ -435,7 +435,7 @@ SparseTensor& mul_out_sparse_cpu(SparseTensor& r, const Tensor& t_, const Tensor } } else { AT_DISPATCH_ALL_TYPES( - r_values.type(), "mul_out_sparse", [&] { + r_values.scalar_type(), "mul_out_sparse", [&] { auto r_accessor = r_values.accessor(); auto t_accessor = t_values.accessor(); auto s_accessor = s_values.accessor(); @@ -551,7 +551,7 @@ Tensor& s_addmm_out_sparse_dense_cpu( Tensor values = sparse_._values(); AT_DISPATCH_ALL_TYPES( - values.type(), "addmm_sparse_dense", [&] { + values.scalar_type(), "addmm_sparse_dense", [&] { s_addmm_out_sparse_dense_worker(nnz, dim_i, dim_j, dim_k, r, beta, t, alpha, indices, values, dense); } ); @@ -757,7 +757,7 @@ SparseTensor& _sspaddmm_out_cpu( int64_t newv_stride0 = newv.stride(0); AT_DISPATCH_ALL_TYPES( - values.type(), "sspmm", [&] { + values.scalar_type(), "sspmm", [&] { auto values_accessor = values.accessor(); scalar_t* dense_ptr = dense.data(); scalar_t* newv_ptr = newv.data(); diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu index fc144d6..3d48a63 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu @@ -95,7 +95,7 @@ SparseTensor coalesce_sparse_cuda(const SparseTensor& self) { dim3 grid(THCCeilDiv(newNnz, (int64_t) 4), THCCeilDiv(stride, (int64_t) 128)); dim3 block(32, 4); AT_DISPATCH_ALL_TYPES_AND( - at::ScalarType::Half,values.type(), "coalesce_sparse_cuda", [&] { + at::ScalarType::Half,values.scalar_type(), "coalesce_sparse_cuda", [&] { using 
cuda_accscalar_t = acc_type; apply::coalesceValuesKernel<<>>( uniqueOffsets.data(), diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu index cac6835..b0f5657 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu @@ -91,7 +91,7 @@ Tensor& s_addmm_out_sparse_dense_cuda(Tensor& r_, const Tensor& t, const SparseT // No half support, so we don't have to use CUDATypeConversion Tensor r__; AT_DISPATCH_FLOATING_TYPES( - values.type(), "addmm_sparse_cuda", [&] { + values.scalar_type(), "addmm_sparse_cuda", [&] { scalar_t cast_beta = beta.to(); scalar_t cast_alpha = alpha.to(); if (cast_beta == 0) { @@ -296,7 +296,7 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, SparseTensorR AT_CHECK(cuda::getApplyGrid(nnz, grid, curDevice), "add: Argument #0: tensor too large or too many dimensions"); AT_DISPATCH_ALL_TYPES_AND( - at::ScalarType::Half, values.type(), "add_out_dense_sparse_cuda", [&] { + at::ScalarType::Half, values.scalar_type(), "add_out_dense_sparse_cuda", [&] { apply::sparseElementwiseKernelScalar, uint64_t, scalar_t> <<>>( TensorCAddOp(value.to()), @@ -310,7 +310,7 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, SparseTensorR values = values.contiguous(); AT_DISPATCH_ALL_TYPES_AND( - at::ScalarType::Half, values.type(), "add_out_dense_sparse_cuda", [&] { + at::ScalarType::Half, values.scalar_type(), "add_out_dense_sparse_cuda", [&] { apply::sparseElementwiseKernel, uint64_t, scalar_t> <<>>( TensorCAddOp(value.to()), @@ -324,7 +324,7 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, SparseTensorR // FIXME: at some point we can wrap the scale into indexAdd // NB: Purposely not inplace! 
AT_DISPATCH_ALL_TYPES_AND( - at::ScalarType::Half, values.type(), "add_out_dense_sparse_cuda", [&] { + at::ScalarType::Half, values.scalar_type(), "add_out_dense_sparse_cuda", [&] { if (value.to() != static_cast(1)) { values = values.mul(value); } @@ -379,7 +379,7 @@ SparseTensor& add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, const Tensor s_values_ = src._values(); AT_DISPATCH_ALL_TYPES_AND( - at::ScalarType::Half, s_values_.type(), "add_out_sparse_cuda", [&] { + at::ScalarType::Half, s_values_.scalar_type(), "add_out_sparse_cuda", [&] { if (value.to() != static_cast(1)) { s_values_ = s_values_.mul(value); } @@ -449,7 +449,7 @@ SparseTensor& mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, cons LongTensor resultNnz = at::empty({1}, CUDA(kLong)); AT_DISPATCH_ALL_TYPES_AND( - at::ScalarType::Half, t_values_.type(), "mul_out_sparse_cuda", [&] { + at::ScalarType::Half, t_values_.scalar_type(), "mul_out_sparse_cuda", [&] { apply::valueSparseIntersectionKernel, uint64_t, scalar_t> <<>>( TensorMulOp(), @@ -620,7 +620,7 @@ Tensor _sparse_sum_backward_cuda(const Tensor& grad_, const SparseTensor& input_ auto input_indices_ti = getTensorInfo(input_indices_1D); auto input_indices_pos_ti = getTensorInfo(input_indices_pos); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_values.type(), "_sparse_sum_backward_cuda", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_values.scalar_type(), "_sparse_sum_backward_cuda", [&] { auto grad_values_expand_ti = getTensorInfo(grad_values_expand); auto grad_input_values_ti = getTensorInfo(grad_input_values); diff --git a/aten/src/ATen/test/apply_utils_test.cpp b/aten/src/ATen/test/apply_utils_test.cpp index 69f1776..cc97c03 100644 --- a/aten/src/ATen/test/apply_utils_test.cpp +++ b/aten/src/ATen/test/apply_utils_test.cpp @@ -27,7 +27,7 @@ void test(Type& type, IntArrayRef shape, int64_t a = 0, int64_t b = 1) { auto zero_dim = at::empty({}, type); zero_dim.fill_(2); zero_dim.exp_(); - AT_DISPATCH_FLOATING_TYPES(zero_dim.type(), "test0", [&] { + AT_DISPATCH_FLOATING_TYPES(zero_dim.scalar_type(), "test0", [&] { ASSERT(zero_dim.data()[0] == std::exp(2)); }); @@ -50,7 +50,7 @@ void test(Type& type, IntArrayRef shape, int64_t a = 0, int64_t b = 1) { } } - AT_DISPATCH_FLOATING_TYPES(a0.type(), "test1", [&] { + AT_DISPATCH_FLOATING_TYPES(a0.scalar_type(), "test1", [&] { CPU_tensor_apply2( a0, a1, [](scalar_t& y, const scalar_t& x) { y = x * x; }); CPU_tensor_apply2( @@ -62,7 +62,7 @@ void test(Type& type, IntArrayRef shape, int64_t a = 0, int64_t b = 1) { } }); - AT_DISPATCH_FLOATING_TYPES(a0.type(), "test2", [&] { + AT_DISPATCH_FLOATING_TYPES(a0.scalar_type(), "test2", [&] { CPU_tensor_apply3( a0, a1, a2, [](scalar_t& y, const scalar_t& x, const scalar_t& z) { y = x * x + z; @@ -79,7 +79,7 @@ void test(Type& type, IntArrayRef shape, int64_t a = 0, int64_t b = 1) { } }); - AT_DISPATCH_FLOATING_TYPES(a0.type(), "test3", [&] { + AT_DISPATCH_FLOATING_TYPES(a0.scalar_type(), "test3", [&] { CPU_tensor_apply4( a0, a1, diff --git a/aten/src/ATen/test/scalar_test.cpp b/aten/src/ATen/test/scalar_test.cpp index 3ca0d3c..24c4da8 100644 --- a/aten/src/ATen/test/scalar_test.cpp +++ b/aten/src/ATen/test/scalar_test.cpp @@ -101,7 +101,7 @@ TEST(TestScalar, TestScalar) { ASSERT_EQ(scalar_to_tensor(ones({}).item()).scalar_type(), kDouble); if (x.scalar_type() != ScalarType::Half) { - AT_DISPATCH_ALL_TYPES(x.type(), "foo", [&] { + AT_DISPATCH_ALL_TYPES(x.scalar_type(), "foo", [&] { scalar_t s = 1; std::stringstream ss; ASSERT_NO_THROW( diff --git 
a/test/test_cpp_extensions.py b/test/test_cpp_extensions.py index e238071..e07b823 100755 --- a/test/test_cpp_extensions.py +++ b/test/test_cpp_extensions.py @@ -338,7 +338,7 @@ class TestCppExtension(common.TestCase): torch::Tensor half_test(torch::Tensor input) { auto output = torch::empty(1, input.options().dtype(torch::kFloat)); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "half_test", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "half_test", [&] { half_test_kernel<<<1, 1>>>( input.data(), output.data()); diff --git a/torch/csrc/TypeInfo.cpp b/torch/csrc/TypeInfo.cpp index 8b853e5..d90bdfd 100644 --- a/torch/csrc/TypeInfo.cpp +++ b/torch/csrc/TypeInfo.cpp @@ -119,7 +119,7 @@ static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) { static PyObject* THPFInfo_eps(THPFInfo* self, void*) { return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES( - at::CPU(self->type), "epsilon", [] { + self->type, "epsilon", [] { return PyFloat_FromDouble( std::numeric_limits< at::scalar_value_type::type>::epsilon()); @@ -127,33 +127,33 @@ static PyObject* THPFInfo_eps(THPFInfo* self, void*) { } static PyObject* THPFInfo_max(THPFInfo* self, void*) { - return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(at::CPU(self->type), "max", [] { + return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self->type, "max", [] { return PyFloat_FromDouble( std::numeric_limits::type>::max()); }); } static PyObject* THPFInfo_min(THPFInfo* self, void*) { - return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(at::CPU(self->type), "min", [] { + return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self->type, "min", [] { return PyFloat_FromDouble( std::numeric_limits::type>::lowest()); }); } static PyObject* THPIInfo_max(THPFInfo* self, void*) { - return AT_DISPATCH_INTEGRAL_TYPES(at::CPU(self->type), "max", [] { + return AT_DISPATCH_INTEGRAL_TYPES(self->type, "max", [] { return THPUtils_packInt64(std::numeric_limits::max()); }); } static PyObject* THPIInfo_min(THPFInfo* self, void*) { - return AT_DISPATCH_INTEGRAL_TYPES(at::CPU(self->type), "min", [] { + return AT_DISPATCH_INTEGRAL_TYPES(self->type, "min", [] { return THPUtils_packInt64(std::numeric_limits::lowest()); }); } static PyObject* THPFInfo_tiny(THPFInfo* self, void*) { - return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(at::CPU(self->type), "min", [] { + return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self->type, "min", [] { return PyFloat_FromDouble( std::numeric_limits::type>::min()); });
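
For reviewers following the call-site churn above: the pattern is the same everywhere. Each AT_DISPATCH_* macro now takes an at::ScalarType (typically tensor.scalar_type(), or iter.dtype() for TensorIterator kernels) where it used to take an at::Type. Below is a minimal sketch of a post-patch call site, using a hypothetical my_square_cpu op that is not part of this change:

    #include <ATen/ATen.h>
    #include <ATen/Dispatch.h>

    at::Tensor my_square_cpu(const at::Tensor& self) {
      auto input = self.contiguous();
      auto result = at::empty_like(input);
      // Before this patch the first argument was input.type(); now the
      // macro switches directly on the ScalarType.
      AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "my_square_cpu", [&] {
        const scalar_t* in = input.data<scalar_t>();
        scalar_t* out = result.data<scalar_t>();
        for (int64_t i = 0; i < input.numel(); ++i) {
          out[i] = in[i] * in[i];
        }
      });
      return result;
    }

Note that several kernel names passed to the macros were also suffixed with _cpu or _cuda (e.g. "prod" becomes "prod_cuda") so that error messages identify the backend; that rename is cosmetic and independent of the Type-to-ScalarType switch.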
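
The reduction kernels in the first hunk above (prod, mean, norm) follow the same convention but branch on mixed half/float dtypes before falling back to the generic dispatch. A sketch of that shape, using a placeholder reduce_impl: the real *_kernel_impl template argument lists were stripped by the page extraction, so the <at::Half, float> parameters below are illustrative, and the TensorIterator header path is assumed.

    #include <ATen/ATen.h>
    #include <ATen/Dispatch.h>
    #include <ATen/native/TensorIterator.h>

    // Placeholder for the real prod/mean/norm *_kernel_impl templates.
    template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
    void reduce_impl(at::TensorIterator& iter) { /* launch the reduction kernel */ }

    static void example_reduce_kernel_cuda(at::TensorIterator& iter) {
      if (iter.dtype() == at::kHalf) {
        // half in, half out: accumulate in float
        return reduce_impl<at::Half, float>(iter);
      } else if (iter.dtype(1) == at::kHalf && iter.dtype() == at::kFloat) {
        // half input promoted to float output: cast and reduce in one kernel
        return reduce_impl<at::Half, float, float>(iter);
      }
      AT_DISPATCH_ALL_TYPES(iter.dtype(), "example_reduce_cuda", [&]() {
        reduce_impl<scalar_t>(iter);
      });
    }

Here iter.dtype() is the output dtype and iter.dtype(1) the first input, which is why the half-to-float promotion case checks both before choosing the fused cast-and-reduce path.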