namespace at { namespace native {
-template<typename scalar_t, typename accscalar_t>
-struct MulScalarFunctor {
- MulScalarFunctor(accscalar_t b_): b(b_) {}
- __device__ scalar_t operator() (scalar_t a) const {
- return a * b;
- }
- private:
- accscalar_t b;
-};
-
template<typename scalar_t>
struct DivFunctor {
  __device__ scalar_t operator() (scalar_t a, scalar_t b) const {
    return a / b;
  }
};
-template<typename scalar_t>
+template<typename T>
struct MulFunctor {
- __device__ scalar_t operator() (scalar_t a, scalar_t b) const {
+ __device__ T operator() (T a, T b) const {
return a * b;
}
};
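+// The new scalar path below relies on BUnaryFunctor from ATen's Loops.cuh to
+// bind the second operand of a binary functor, turning it into a unary functor
+// whose bound scalar is kept at opmath precision. A rough sketch of that
+// helper (not the verbatim upstream definition):
+//
+//   template <typename arg1_t, typename arg2_t, typename return_t, typename func_t>
+//   struct BUnaryFunctor {
+//     using opmath_arg2_t = typename binary_function_traits<func_t>::arg2_t;
+//     BUnaryFunctor(func_t f_, opmath_arg2_t b_) : f(f_), b(b_) {}
+//     __device__ return_t operator()(arg1_t a) const { return f(a, b); }
+//    private:
+//     func_t f;
+//     opmath_arg2_t b;  // NB: the scalar is stored at the higher (opmath) precision
+//   };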
void div_true_kernel_cuda(TensorIteratorBase& iter) {
  if (iter.is_cpu_scalar(2)) {
    // optimization for floating-point types: if the second operand is a CPU
    // scalar, compute a * reciprocal(b). Note that this may lose one bit of
    // precision compared to computing the division.
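+    // For example, in double arithmetic 6.0 / 3.0 is exactly 2.0, but
+    // 6.0 * (1.0 / 3.0) evaluates to 1.9999999999999998 because 1.0 / 3.0
+    // is rounded before the multiply.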
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "div_true_cuda", [&]() {
- using accscalar_t = at::acc_type<scalar_t, true>;
- auto inv_b = accscalar_t(1.0) / iter.scalar_value<accscalar_t>(2);
+ using opmath_t = at::opmath_type<scalar_t>;
+ auto inv_b = opmath_t(1.0) / iter.scalar_value<opmath_t>(2);
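+    // at::opmath_type<scalar_t> is float for Half and BFloat16 and scalar_t
+    // itself otherwise, so the reciprocal is computed at full precision and
+    // only the final result is rounded back to scalar_t.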
iter.remove_operand(2);
- MulScalarFunctor<scalar_t, decltype(inv_b)> f(inv_b);
- gpu_kernel(iter, f);
+ gpu_kernel(iter, BUnaryFunctor<scalar_t, scalar_t, scalar_t, MulFunctor<opmath_t>>(
+ MulFunctor<opmath_t>(), inv_b));
});
} else {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "div_true_cuda", [&]() {
  DivFunctor<scalar_t> f;
  gpu_kernel_with_scalars(iter, f);
});
}
}
void mul_kernel_cuda(TensorIteratorBase& iter) {
- if (!isIntegralType(iter.common_dtype(), /*includeBool*/ true) &&
- (iter.is_cpu_scalar(1) || iter.is_cpu_scalar(2))) {
- //if common dtype is half the scalar constant can overflow in half precision, and yet the result can
- //still be representable in the half dtype. Cast scalar to acc_type to have better accuracy
- AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "mul_cuda", [&]() {
- using accscalar_t = at::acc_type<scalar_t, true>;
- int scalar_arg = iter.is_cpu_scalar(1) ? 1 : 2;
- auto b = iter.scalar_value<accscalar_t>(scalar_arg);
- iter.remove_operand(scalar_arg);
- const cuda::OptionalCUDAGuard device_guard(device_of(iter.tensor(1)));
- MulScalarFunctor<scalar_t, decltype(b)> f(b);
- gpu_kernel(iter, f);
- });
- } else {
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "mul_cuda", [&]() {
- MulFunctor<scalar_t> f;
- gpu_kernel_with_scalars(iter, f);
- });
- }
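+  // opmath_gpu_kernel_with_scalars subsumes the CPU-scalar special case
+  // deleted above: any CPU scalar operand is lifted to opmath_t and bound
+  // into the functor, so Half/BFloat16 scalars no longer overflow (see the
+  // sketch after this function).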
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "mul_cuda", [&]() {
+ using opmath_t = at::opmath_type<scalar_t>;
+ opmath_gpu_kernel_with_scalars<scalar_t>(iter, MulFunctor<opmath_t>());
+ });
}
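+// A rough sketch of opmath_gpu_kernel_with_scalars as assumed here (the real
+// helper lives in ATen's Loops.cuh; names and details are approximate):
+//
+//   template <typename scalar_t, typename return_t = scalar_t, typename func_t>
+//   void opmath_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
+//     using opmath_t = at::opmath_type<scalar_t>;
+//     if (iter.is_cpu_scalar(1)) {
+//       AUnaryFunctor<scalar_t, scalar_t, return_t, func_t> af(f, iter.scalar_value<opmath_t>(1));
+//       iter.remove_operand(1);
+//       gpu_kernel(iter, af);
+//     } else if (iter.is_cpu_scalar(2)) {
+//       BUnaryFunctor<scalar_t, scalar_t, return_t, func_t> bf(f, iter.scalar_value<opmath_t>(2));
+//       iter.remove_operand(2);
+//       gpu_kernel(iter, bf);
+//     } else {
+//       gpu_kernel(iter, BinaryFunctor<scalar_t, scalar_t, return_t, func_t>(f));
+//     }
+//   }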
REGISTER_DISPATCH(div_true_stub, &div_true_kernel_cuda);
REGISTER_DISPATCH(mul_stub, &mul_kernel_cuda);

}} // namespace at::native