Convert mul to use opmath_gpu_kernel_with_scalars (#64019)
authorEdward Yang <ezyang@fb.com>
Wed, 1 Sep 2021 00:55:23 +0000 (17:55 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 1 Sep 2021 01:33:30 +0000 (18:33 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64019

Note that previously the functor operated on scalar_t and
this modifies it to operate on opmath_t, but this is not
a problem as half precision was implemented by performing the
compute in float anyway.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D30575282

Pulled By: ezyang

fbshipit-source-id: cc6900ef996e755740afe48f9cb4d0366858dd47

aten/src/ATen/native/cuda/BinaryMulDivKernel.cu

index da615fe..e6a5300 100644 (file)
 
 namespace at { namespace native {
 
-template<typename scalar_t, typename accscalar_t>
-struct MulScalarFunctor {
-    MulScalarFunctor(accscalar_t b_): b(b_) {}
-    __device__ scalar_t operator() (scalar_t a) const {
-      return a * b;
-    }
-  private:
-    accscalar_t b;
-};
-
 template<typename scalar_t>
 struct DivFunctor {
   __device__ scalar_t operator() (scalar_t a, scalar_t b) const {
@@ -31,9 +21,9 @@ struct DivFunctor {
   }
 };
 
-template<typename scalar_t>
+template<typename T>
 struct MulFunctor {
-  __device__ scalar_t operator() (scalar_t a, scalar_t b) const {
+  __device__ T operator() (T a, T b) const {
     return a * b;
   }
 };
@@ -53,11 +43,11 @@ void div_true_kernel_cuda(TensorIteratorBase& iter) {
     // scalar, compute a * reciprocal(b). Note that this may lose one bit of
     // precision compared to computing the division.
     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "div_true_cuda", [&]() {
-      using accscalar_t = at::acc_type<scalar_t, true>;
-      auto inv_b = accscalar_t(1.0) / iter.scalar_value<accscalar_t>(2);
+      using opmath_t = at::opmath_type<scalar_t>;
+      auto inv_b = opmath_t(1.0) / iter.scalar_value<opmath_t>(2);
       iter.remove_operand(2);
-      MulScalarFunctor<scalar_t, decltype(inv_b)> f(inv_b);
-      gpu_kernel(iter, f);
+      gpu_kernel(iter, BUnaryFunctor<scalar_t, scalar_t, scalar_t, MulFunctor<opmath_t>>(
+        MulFunctor<opmath_t>(), inv_b));
     });
   } else {
     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "div_true_cuda", [&]() {
@@ -180,25 +170,10 @@ void div_floor_kernel_cuda(TensorIteratorBase& iter) {
 }
 
 void mul_kernel_cuda(TensorIteratorBase& iter) {
-  if (!isIntegralType(iter.common_dtype(), /*includeBool*/ true) &&
-    (iter.is_cpu_scalar(1) || iter.is_cpu_scalar(2))) {
-    //if common dtype is half the scalar constant can overflow in half precision, and yet the result can
-    //still be representable in the half dtype. Cast scalar to acc_type to have better accuracy
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "mul_cuda", [&]() {
-      using accscalar_t = at::acc_type<scalar_t, true>;
-      int scalar_arg = iter.is_cpu_scalar(1) ? 1 : 2;
-      auto b = iter.scalar_value<accscalar_t>(scalar_arg);
-      iter.remove_operand(scalar_arg);
-      const cuda::OptionalCUDAGuard device_guard(device_of(iter.tensor(1)));
-      MulScalarFunctor<scalar_t, decltype(b)> f(b);
-      gpu_kernel(iter, f);
-    });
-  } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "mul_cuda", [&]() {
-      MulFunctor<scalar_t> f;
-      gpu_kernel_with_scalars(iter, f);
-    });
-  }
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "mul_cuda", [&]() {
+    using opmath_t = at::opmath_type<scalar_t>;
+    opmath_gpu_kernel_with_scalars<scalar_t>(iter, MulFunctor<opmath_t>());
+  });
 }
 
 REGISTER_DISPATCH(div_true_stub, &div_true_kernel_cuda);