Add acc_gpu_kernel_with_scalars and port add to use it (#63884)
authorEdward Yang <ezyang@fb.com>
Tue, 31 Aug 2021 02:08:45 +0000 (19:08 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 31 Aug 2021 02:10:16 +0000 (19:10 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63884

See https://dev-discuss.pytorch.org/t/cuda-loops-case-study-code-generation-vs-templates/302
for explanation of what's going on here.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D30545296

Pulled By: ezyang

fbshipit-source-id: f0da52153ae63599fe1d57e90e73f50ca2116939

aten/src/ATen/native/cuda/BinaryAddSubKernel.cu
aten/src/ATen/native/cuda/Loops.cuh

index a07fd66..b1c76e1 100644 (file)
 
 namespace at { namespace native {
 
-template<typename scalar_t, typename accscalar_t>
+template <typename T>
 struct AddFunctor {
-  AddFunctor(accscalar_t a): alpha(a) {}
-  __device__ __forceinline__ scalar_t operator() (const scalar_t a, const scalar_t b) const {
-    return a + alpha * b;
+  AddFunctor(T alpha) : alpha_(alpha) {}
+  T alpha_;
+  __device__ __forceinline__ T operator()(T a, T b) const __ubsan_ignore_undefined__ {
+    return a + b * alpha_;
   }
-  private:
-    accscalar_t alpha;
-};
-
-template<typename scalar_t, typename accscalar_t, int SCALAR_ARG>
-struct AddScalarFunctor {
-  static_assert(SCALAR_ARG == 1 || SCALAR_ARG == 2, "SCALAR_ARG must be either 1 or 2");
-  AddScalarFunctor(accscalar_t alpha, accscalar_t b): alpha(alpha), b(b) {}
-  __device__ __forceinline__ scalar_t operator() (const scalar_t a) const {
-    return static_cast<scalar_t>(SCALAR_ARG == 1 ? b + alpha * a : a + alpha * b);
-  }
-  private:
-    accscalar_t alpha;
-    accscalar_t b;
 };
 
 void add_kernel_cuda(TensorIteratorBase& iter, const Scalar& alpha_scalar) {
-  if (!isIntegralType(iter.common_dtype(), /* includeBool */ true) && (iter.is_cpu_scalar(1) || iter.is_cpu_scalar(2))) {
-    // if common dtype is half the scalar constant can overflow in half precision, and yet the result can
-    // still be representable in the half dtype. Cast scalar to acc_type to have better accuracy.
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "add_cuda/sub_cuda", [&]() {
-      using accscalar_t = at::acc_type<scalar_t, true>;
-      int scalar_arg = iter.is_cpu_scalar(1) ? 1 : 2;
-      auto b = iter.scalar_value<accscalar_t>(scalar_arg);
-      iter.remove_operand(scalar_arg);
-      const cuda::OptionalCUDAGuard device_guard(device_of(iter.tensor(1)));
-      if (scalar_arg == 1) {
-        AddScalarFunctor<scalar_t, decltype(b), 1> f(alpha_scalar.to<accscalar_t>(), b);
-        gpu_kernel(iter, f);
-      } else {
-        AddScalarFunctor<scalar_t, decltype(b), 2> f(alpha_scalar.to<accscalar_t>(), b);
-        gpu_kernel(iter, f);
-      }
-    });
-  } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBool, kBFloat16, iter.common_dtype(), "add_cuda/sub_cuda", [&]() {
-      using accscalar_t = at::acc_type<scalar_t, true>;
-      AddFunctor<scalar_t, accscalar_t> f(alpha_scalar.to<accscalar_t>());
-      gpu_kernel_with_scalars(iter, f);
-    });
-  }
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBool, kBFloat16, iter.common_dtype(), "add_cuda/sub_cuda", [&]() {
+    using opmath_t = at::opmath_type<scalar_t>;
+    opmath_gpu_kernel_with_scalars<scalar_t>(iter, AddFunctor<opmath_t>(alpha_scalar.to<opmath_t>()));
+  });
 }
 
 static void sub_kernel_cuda(TensorIteratorBase& iter, const Scalar& alpha_scalar) {
index fde8e86..8849293 100644 (file)
@@ -5,6 +5,7 @@
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/TensorIteratorDynamicCasting.h>
 #include <ATen/cuda/detail/OffsetCalculator.cuh>
+#include <ATen/OpMathType.h>
 
 #include <thrust/tuple.h>
 
@@ -111,49 +112,64 @@ void gpu_kernel(TensorIteratorBase& iter, const func_t& f) {
   gpu_kernel_impl(iter, f);
 }
 
-template<typename func_t>
+template<typename arg1_t, typename arg2_t, typename return_t, typename func_t>
 struct AUnaryFunctor {
   using traits = function_traits<func_t>;
-  using arg1_t = typename traits::template arg<0>::type;
-  using arg2_t = typename traits::template arg<1>::type;
-  using return_t = typename traits::result_type;
+  using opmath_arg1_t = typename traits::template arg<0>::type;
   __device__ return_t operator()(arg2_t b) const {
     return f(a, b);
   }
-  AUnaryFunctor(func_t f_, arg1_t a_): f(f_), a(a_) {}
+  // NB: scalar is stored in higher precision!
+  AUnaryFunctor(func_t f_, opmath_arg1_t a_): f(f_), a(a_) {}
   private:
     func_t f;
-    arg1_t a;
+    opmath_arg1_t a;
 };
 
-template<typename func_t>
+template<typename arg1_t, typename arg2_t, typename return_t, typename func_t>
 struct BUnaryFunctor {
   using traits = function_traits<func_t>;
-  using arg1_t = typename traits::template arg<0>::type;
-  using arg2_t = typename traits::template arg<1>::type;
-  using return_t = typename traits::result_type;
+  using opmath_arg2_t = typename traits::template arg<1>::type;
   __device__ return_t operator()(arg1_t a) const {
     return f(a, b);
   }
-  BUnaryFunctor(func_t f_, arg2_t b_): f(f_), b(b_) {}
+  // NB: scalar is stored in higher precision!
+  BUnaryFunctor(func_t f_, opmath_arg2_t b_): f(f_), b(b_) {}
   private:
     func_t f;
-    arg2_t b;
+    opmath_arg2_t b;
 };
 
-template <typename func_t>
-void gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
+// Though seemingly noop, this inserts casts from arg1_t to func_t's type
+// (which may be higher precision), as well as casts to return_t
+template <typename arg1_t, typename arg2_t, typename return_t, typename func_t>
+struct BinaryFunctor {
+  __device__ return_t operator()(arg1_t a, arg2_t b) const {
+    return f(a, b);
+  }
+  BinaryFunctor(func_t f_): f(f_) {}
+  private:
+    func_t f;
+};
+
+// Unlike gpu_kernel_with_scalars, this allows you to pass a func_t which
+// accepts inputs at higher precision (typically opmath_t), but then
+// ensure that we load from memory at the correct precision (scalar_t)
+// to avoid expensive loads.  For the whole sordid story see
+// https://dev-discuss.pytorch.org/t/cuda-loops-case-study-code-generation-vs-templates/302
+template <typename arg1_t, typename arg2_t = arg1_t, typename return_t = arg1_t, typename func_t>
+void opmath_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
   TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
 
   using traits = function_traits<func_t>;
+  using opmath_arg1_t = typename traits::template arg<0>::type;
+  using opmath_arg2_t = typename traits::template arg<1>::type;
   static_assert(
       traits::arity == 2,
       "gpu_kernel_with_scalars only supports two input arguments");
 
-  using arg1_t = typename traits::template arg<0>::type;
-  using arg2_t = typename traits::template arg<1>::type;
   if (iter.is_cpu_scalar(1)) {
-    AUnaryFunctor<func_t> af(f, iter.scalar_value<arg1_t>(1));
+    AUnaryFunctor<arg1_t, arg2_t, return_t, func_t> af(f, iter.scalar_value<opmath_arg1_t>(1));
     iter.remove_operand(1);
     // TODO: When all kernels that use gpu_kernel_with_scalars are
     // ported to structured, this device guard can be deleted.  This
@@ -163,14 +179,28 @@ void gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
     const OptionalDeviceGuard device_guard(device_of(iter.tensor(1)));
     gpu_kernel(iter, af);
   } else if (iter.is_cpu_scalar(2)) {
-    BUnaryFunctor<func_t> bf(f, iter.scalar_value<arg2_t>(2));
+    BUnaryFunctor<arg1_t, arg2_t, return_t, func_t> bf(f, iter.scalar_value<opmath_arg2_t>(2));
     iter.remove_operand(2);
     gpu_kernel(iter, bf);
   } else {
-    gpu_kernel(iter, f);
+    gpu_kernel(iter, BinaryFunctor<arg1_t, arg2_t, return_t, func_t>(f));
   }
 }
 
+// Legacy variant that assumes that func_t has the correct types
+// that we expect to load from memory
+template <typename func_t>
+void gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
+  using traits = function_traits<func_t>;
+  static_assert(
+      traits::arity == 2,
+      "gpu_kernel_with_scalars only supports two input arguments");
+  using arg1_t = typename traits::template arg<0>::type;
+  using arg2_t = typename traits::template arg<1>::type;
+  using return_t = typename traits::result_type;
+  opmath_gpu_kernel_with_scalars<arg1_t, arg2_t, return_t, func_t>(iter, f);
+}
+
 namespace { // functions for `gpu_kernel_multiple_outputs`.
 
 // check the return type is `thrust::tuple`, not `std::tuple`.