(#14580)
authorJie <jiej@nvidia.com>
Thu, 6 Dec 2018 16:57:39 +0000 (08:57 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 6 Dec 2018 17:03:46 +0000 (09:03 -0800)
Summary:
Removes cast of half to float in torch.sum, with float16 input tensor and
float32 output tensor, instead we cast data when loading input in kernel.

This supposingly would save a kernel launch as well as a full global memory load
on promoted data type (float).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14580

Differential Revision: D13356203

Pulled By: ezyang

fbshipit-source-id: 85e91225b880a65fe3ceb493371b9b36407fdf48

aten/src/ATen/native/ReduceOps.cpp
aten/src/ATen/native/TensorIterator.cpp
aten/src/ATen/native/TensorIterator.h
aten/src/ATen/native/cuda/Reduce.cuh
aten/src/ATen/native/cuda/ReduceOpsKernel.cu
test/test_cuda.py

index 5e297c8..e23feb6 100644 (file)
@@ -95,10 +95,16 @@ static std::unique_ptr<TensorIterator> make_reduction(
   auto mask = make_dim_mask(dim, ndim);
   allocate_reduction_result(result, self, mask, keepdim, dtype);
   auto viewed_result = review_reduce_result(result, ndim, mask, keepdim);
-  if (self.type().scalarType() != dtype) {
-    return TensorIterator::reduce_op(viewed_result, self.to(dtype));
+
+  // special case for type promotion in mixed precision, improves computational
+  // efficiency.
+  // not generalize this to common mismatched input/output types to avoid cross
+  // product of templated kernel launches.
+  if (self.type().scalarType() == dtype || 
+      (self.is_cuda() && self.type().scalarType() == kHalf && dtype == kFloat)) {
+    return TensorIterator::reduce_op(viewed_result, self);
   }
-  return TensorIterator::reduce_op(viewed_result, self);
+  return TensorIterator::reduce_op(viewed_result, self.to(dtype));
 }
 
 static inline int64_t n_dim_size(const Tensor& self, IntList dim) {
index 2933c0b..27a758f 100644 (file)
@@ -115,6 +115,13 @@ void TensorIterator::compute_types() {
             type.device_type() == kCUDA && op.tensor.type().device_type() == kCPU) {
           // don't cast CPU scalars in CUDA ops that directly support them
           op.type = &op.tensor.type();
+        } else if (promote_gpu_output_dtypes_ && op.tensor.defined() &&
+            !op.is_output && op.tensor.type().scalarType() == kHalf &&
+            type.scalarType() == kFloat && type.device_type() == kCUDA &&
+            op.tensor.type().device_type() == kCUDA) {
+          // allow input tensor type upcasting for fp16 to fp32 in fused kernel
+          // on GPU
+          op.type = &op.tensor.type();
         } else {
           op.type = &type;
         }
@@ -475,6 +482,7 @@ std::unique_ptr<TensorIterator> TensorIterator::reduce_op(Tensor& out, const Ten
   auto builder = TensorIterator::Builder();
   builder.add_output(out);
   builder.add_input(a);
+  builder.iter_->promote_gpu_output_dtypes_ = true;
   builder.iter_->resize_outputs_ = false;
   builder.iter_->is_reduction_ = true;
   return builder.build();
index 62d9a4b..4c36ed9 100644 (file)
@@ -228,6 +228,7 @@ protected:
   bool is_reduction_ = false;
   bool compute_common_dtype_ = true;
   bool allow_cpu_scalars_ = false;
+  bool promote_gpu_output_dtypes_ = false;
 };
 
 struct TensorIterator::Builder {
index c67c81f..7635341 100644 (file)
@@ -197,7 +197,7 @@ __device__ Array<type_t, vt> load_memory(const type_t* in, int begin, int end, i
   return load_memory<vt>(in, begin, end, stride, [](int idx) { return idx; });
 }
 
-template <typename scalar_t, typename func_t>
+template <typename scalar_t, typename func_t, typename out_scalar_t=scalar_t>
 struct ReduceOp {
   using traits = binary_function_traits<func_t>;
   using arg_t = typename traits::arg2_t;
@@ -248,7 +248,7 @@ struct ReduceOp {
       value = warp_reduce(value);
     }
 
-    auto out = (scalar_t*)((char*)dst + base_offsets[0]);
+    auto out = (out_scalar_t*)((char*)dst + base_offsets[0]);
     if (config.should_global_reduce()) {
       value = global_reduce(value, out);
     } else if (config.should_store(output_idx)) {
@@ -335,7 +335,7 @@ struct ReduceOp {
     return is_last_block_done;
   }
 
-  C10_DEVICE arg_t global_reduce(arg_t value, scalar_t* out) const {
+  C10_DEVICE arg_t global_reduce(arg_t value, out_scalar_t* out) const {
     arg_t* reduce_buffer = (arg_t*)buffer;
 
     bool should_store = config.should_store(config.output_idx());
@@ -393,14 +393,14 @@ static void launch_reduce_kernel(const ReduceConfig& config, const R& reduction)
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
-template <typename scalar_t, typename func_t, typename ident_t=double>
+template <typename scalar_t, typename out_scalar_t, typename func_t, typename ident_t=double>
 inline void gpu_reduce_kernel(TensorIterator& iter, const func_t& op, ident_t ident=0) {
   ASSERT_HOST_DEVICE_LAMBDA(func_t);
   AT_ASSERT(iter.numel() > 0 && iter.ntensors() == 2);
 
   if (!iter.can_use_32bit_indexing()) {
     for (auto& sub_iter : iter.with_32bit_indexing()) {
-      gpu_reduce_kernel<scalar_t>(sub_iter, op);
+      gpu_reduce_kernel<scalar_t, out_scalar_t>(sub_iter, op);
     }
     return;
   }
@@ -465,7 +465,7 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const func_t& op, ident_t id
     auto stream = at::cuda::getCurrentCUDAStream();
     AT_CUDA_CHECK(cudaMemsetAsync(semaphores.get(), 0, config.semaphore_size(), stream));
   }
-  auto reduce = ReduceOp<scalar_t, func_t>(
+  auto reduce = ReduceOp<scalar_t, func_t, out_scalar_t>(
       op,
       config,
       input_calc,
index 6a1b008..71e21db 100644 (file)
@@ -11,9 +11,9 @@
 
 namespace at { namespace native {
 
-template <typename scalar_t, typename acc_t=scalar_t>
+template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
 void sum_kernel_impl(TensorIterator& iter) {
-  gpu_reduce_kernel<scalar_t>(iter, []GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
+  gpu_reduce_kernel<scalar_t, out_t>(iter, []GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
     return a + b;
   });
 }
@@ -25,7 +25,7 @@ void sum_kernel_impl<int16_t, int16_t>(TensorIterator& iter) {
   // compiler segfaults:
   // https://bugs.llvm.org/show_bug.cgi?id=39602
   // To work around it, use int32 as the accumulate type.
-  gpu_reduce_kernel<int16_t>(iter, []GPU_LAMBDA(int32_t a, int32_t b) -> int32_t {
+  gpu_reduce_kernel<int16_t, int16_t>(iter, []GPU_LAMBDA(int32_t a, int32_t b) -> int32_t {
     return a + b;
   });
 }
@@ -33,7 +33,7 @@ void sum_kernel_impl<int16_t, int16_t>(TensorIterator& iter) {
 
 template <typename scalar_t, typename acc_t=scalar_t>
 void prod_kernel_impl(TensorIterator& iter) {
-  gpu_reduce_kernel<scalar_t>(iter, []GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
+  gpu_reduce_kernel<scalar_t, scalar_t>(iter, []GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
     return a * b;
   }, 1);
 }
@@ -41,6 +41,9 @@ void prod_kernel_impl(TensorIterator& iter) {
 static void sum_kernel_cuda(TensorIterator& iter) {
   if (iter.type().scalarType() == kHalf) {
     return sum_kernel_impl<at::Half, float>(iter);
+  } else if (iter.type(1).scalarType() == kHalf && iter.type().scalarType() == kFloat) {
+    // type promotion that does cast and reduction in a single kernel
+    return sum_kernel_impl<at::Half, float, float>(iter);
   }
   AT_DISPATCH_ALL_TYPES(iter.type(), "sum", [&]() {
     sum_kernel_impl<scalar_t>(iter);
index 5249f36..1f8b2ad 100644 (file)
@@ -1541,6 +1541,21 @@ class TestCuda(TestCase):
         self.assertEqual(gpu_tensor0[0], 2)
 
     @skipIfRocm
+    def test_sum_cpu_gpu_mismatch(self):
+        x = torch.randn(20, dtype=torch.float32, device='cuda')
+        y = torch.randn(1, dtype=torch.float32)
+        with self.assertRaisesRegex(RuntimeError, 'expected type'
+                                    ' torch.FloatTensor but got'
+                                    ' torch.cuda.FloatTensor'):
+            torch.sum(x, dim=[0], dtype=torch.float32, out=y)
+        # makeing sure half to float promotion is also properly working.
+        x = x.half()
+        with self.assertRaisesRegex(RuntimeError, 'expected type'
+                                    ' torch.FloatTensor but got'
+                                    ' torch.cuda.HalfTensor'):
+            torch.sum(x, dim=[0], dtype=torch.float32, out=y)
+
+    @skipIfRocm
     def test_sum_noncontig(self):
         x = torch.randn(1, 75, 57, 20, device='cuda').permute(0, 3, 1, 2)
         y = x.cpu()
@@ -1555,6 +1570,7 @@ class TestCuda(TestCase):
 
         x = torch.ones(65504, device='cuda', dtype=torch.float16)
         self.assertEqual(x.sum(), 65504)
+        self.assertEqual(x.sum(dtype=torch.float32), 65504)
 
         a = torch.zeros(1203611).bernoulli_(0.0005)
         x = a.to(device='cuda', dtype=torch.float16)