[TensorIterator fixing mean to output correct result for half precisi… (#14878)
authorJie <jiej@nvidia.com>
Tue, 18 Dec 2018 04:08:15 +0000 (20:08 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 18 Dec 2018 04:13:30 +0000 (20:13 -0800)
Summary:
…on](#12115)

mean is calculated in two step sum()/numel(). For half precision, data gets
casted back to half after sum().
We fused the division into the reduction kernel by adding pre_op/post_op.

This allows us to do torch.ones(65536).cuda().half().mean() to return correct
result.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14878

Differential Revision: D13491159

Pulled By: soumith

fbshipit-source-id: e83802e1628b6d2615c45e18d7acf991d143a09e

aten/src/ATen/native/ReduceOps.cpp
aten/src/ATen/native/ReduceOps.h
aten/src/ATen/native/TensorIterator.cpp
aten/src/ATen/native/cpu/Reduce.h
aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
aten/src/ATen/native/cuda/Reduce.cuh
aten/src/ATen/native/cuda/ReduceOpsKernel.cu
test/test_cuda.py
tools/autograd/derivatives.yaml

index 96345c7..e4b3f89 100644 (file)
@@ -24,6 +24,7 @@ DEFINE_DISPATCH(sum_stub);
 DEFINE_DISPATCH(std_stub);
 DEFINE_DISPATCH(prod_stub);
 DEFINE_DISPATCH(norm_kernel);
+DEFINE_DISPATCH(mean_stub);
 
 static inline Tensor integer_upcast(const Tensor& self, optional<ScalarType> dtype) {
   ScalarType scalarType = self.type().scalarType();
@@ -183,29 +184,6 @@ Tensor& cumprod_out(Tensor& result, const Tensor& self, int64_t dim) {
 
 // ALL REDUCE #################################################################
 
-static inline Tensor mean(const Tensor &self, optional<ScalarType> dtype) {
-  ScalarType scalarType = self.type().scalarType();
-  AT_CHECK(
-      at::isFloatingType(scalarType),
-      "Can only calculate the mean of floating types. Got ",
-      toString(scalarType),
-      " instead.");
-  if (self.numel() > 0) {
-    Tensor result = at::native::sum(self);
-    return result.div_(self.numel());
-  } else {
-    return at::scalar_tensor(std::numeric_limits<double>::quiet_NaN(), self.options());
-  }
-}
-
-Tensor mean(const Tensor &self, ScalarType dtype) {
-  return at::native::mean(self, optional<ScalarType>(dtype));
-}
-
-Tensor mean(const Tensor &self) {
-  return at::native::mean(self, c10::nullopt);
-}
-
 static ScalarType get_dtype(Tensor& result, const Tensor& self, optional<ScalarType> dtype,
                             bool promote_integers=false) {
   if (dtype.has_value()) {
@@ -272,32 +250,28 @@ Tensor prod(const Tensor &self) {
   return at::native::prod(self, {}, false, c10::nullopt);
 }
 
-// \ALL REDUCE ################################################################
-
-// DIM REDUCE #################################################################
-
 static inline Tensor &mean_out(Tensor &result, const Tensor &self, IntList dim,
-                 bool keepdim, optional<ScalarType> dtype) {
-  ScalarType scalarType = result.type().scalarType();
+                 bool keepdim, optional<ScalarType> opt_dtype) {
+  ScalarType scalarType = opt_dtype.has_value() ? opt_dtype.value() : self.type().scalarType();
   AT_CHECK(
       at::isFloatingType(scalarType),
       "Can only calculate the mean of floating types. Got ",
       toString(scalarType),
       " instead.");
-  at::native::sum_out(
-      result, self.toType(result.type().scalarType()), dim, keepdim);
-  if (result.numel() > 0 && self.ndimension() > 0) {
-    int64_t numel = n_dim_size(self, dim);
-    if (numel > 0) {
-      result.div_(numel);
-    } else {
-      // NumPy equivalent
-      result.fill_(std::numeric_limits<double>::quiet_NaN());
-    }
+  ScalarType dtype = get_dtype(result, self, opt_dtype, true);
+  auto iter = make_reduction("mean", result, self, dim, keepdim, dtype);
+  if (iter->numel() == 0) {
+    result.fill_(std::numeric_limits<double>::quiet_NaN());
+  } else {
+    mean_stub(iter->device_type(), *iter);
   }
   return result;
 }
 
+// \ALL REDUCE ################################################################
+
+// DIM REDUCE #################################################################
+
 Tensor& mean_out(Tensor& result, const Tensor& self, IntList dim, bool keepdim, ScalarType dtype) {
   return at::native::mean_out(
       result, self, dim, keepdim, c10::optional<ScalarType>(dtype));
@@ -310,6 +284,23 @@ Tensor& mean_out(Tensor& result, const Tensor& self, IntList dim, ScalarType dty
   return at::native::mean_out(result, self, dim, false, dtype);
 }
 
+static inline Tensor mean(const Tensor &self, IntList dim, bool keepdim, optional<ScalarType> dtype) {
+  Tensor result;
+  return at::native::mean_out(result, self, dim, keepdim, dtype);
+}
+
+static inline Tensor mean(const Tensor &self, optional<ScalarType> dtype) {
+  return at::native::mean(self, {}, false, dtype);
+}
+
+Tensor mean(const Tensor &self, ScalarType dtype) {
+  return at::native::mean(self, optional<ScalarType>(dtype));
+}
+
+Tensor mean(const Tensor &self) {
+  return at::native::mean(self, c10::nullopt);
+}
+
 Tensor& sum_out(Tensor& result, const Tensor& self, IntList dim, bool keepdim, ScalarType dtype) {
   return at::native::sum_out(
       result, self, dim, keepdim, c10::optional<ScalarType>(dtype));
@@ -336,26 +327,6 @@ Tensor& prod_out(Tensor& result, const Tensor& self, int64_t dim, ScalarType dty
   return at::native::prod_out(result, self, dim, false, dtype);
 }
 
-static inline Tensor mean(const Tensor &self, IntList dim, bool keepdim, optional<ScalarType> dtype) {
-  ScalarType scalarType = self.type().scalarType();
-  AT_CHECK(
-      at::isFloatingType(scalarType),
-      "Can only calculate the mean of floating types. Got ",
-      toString(scalarType),
-      " instead.");
-  Tensor result = at::native::sum(self, dim, keepdim);
-  if (result.numel() > 0 && self.ndimension() > 0) {
-    int64_t numel = n_dim_size(self, dim);
-    if (numel > 0) {
-      result.div_(numel);
-    } else {
-      // NumPy equivalent
-      result.fill_(std::numeric_limits<double>::quiet_NaN());
-    }
-  }
-  return result;
-}
-
 Tensor mean(const Tensor& self, IntList dim, bool keepdim, ScalarType dtype) {
   return at::native::mean(self, dim, keepdim, c10::optional<ScalarType>(dtype));
 }
index 62b8a1d..9bc83d2 100644 (file)
@@ -14,6 +14,7 @@ using reduce_fn = void(*)(TensorIterator &);
 
 DECLARE_DISPATCH(reduce_fn, sum_stub);
 DECLARE_DISPATCH(reduce_fn, prod_stub);
+DECLARE_DISPATCH(reduce_fn, mean_stub);
 
 using reduce_std_function =
   void (*)(TensorIterator&, bool unbiased);
index 4570833..c4cf591 100644 (file)
@@ -120,11 +120,11 @@ void TensorIterator::compute_types() {
     if (op.tensor.defined() && op.tensor.type() != *op.type) {
       if (op.is_output) {
         AT_ERROR("output with type ", op.tensor.type().toString(),
-                 " doesn't match the desired type ", type().toString());
+                 " doesn't match the desired type ", op.type->toString());
       } else if (op.tensor.dim() == 0) {
         op.tensor = op.tensor.to(*op.type);
       } else {
-        AT_ERROR("expected type ", type().toString(), " but got ",
+        AT_ERROR("expected type ", op.type->toString(), " but got ",
             op.tensor.type().toString());
       }
     }
index 2b83585..9e22d5e 100644 (file)
@@ -56,8 +56,9 @@ struct all_same : c10::guts::conjunction<
 
 template <typename rf_t,
           typename cf_t,
-          typename pf_t>
-void binary_kernel_reduce(TensorIterator& iter, rf_t const &reduce, cf_t const &combine, pf_t const &project) {
+          typename pf_t,
+          typename init_t>
+void binary_kernel_reduce(TensorIterator& iter, rf_t const &reduce, cf_t const &combine, pf_t const &project, init_t init) {
   using r_traits = binary_function_traits<rf_t>;
   using c_traits = binary_function_traits<cf_t>;
   using p_traits = unary_function_traits<pf_t>;
@@ -66,6 +67,7 @@ void binary_kernel_reduce(TensorIterator& iter, rf_t const &reduce, cf_t const &
   static_assert(
     all_same<
       acc_t,
+      init_t,
       typename r_traits::arg1_t,
       typename r_traits::result_type,
       typename c_traits::arg1_t,
@@ -84,16 +86,17 @@ void binary_kernel_reduce(TensorIterator& iter, rf_t const &reduce, cf_t const &
     bool serial = numel < at::internal::GRAIN_SIZE || at::get_max_threads() == 1 || at::in_parallel_region();
     int max_threads = serial ? 1 : at::get_max_threads();
     AT_ASSERT(max_threads > 0);
-    std::vector<optional<acc_t>> buffer{(unsigned)max_threads, optional<acc_t> {}};
+    std::vector<optional<acc_t>> buffer((unsigned)max_threads, optional<acc_t> {});
     at::parallel_for(0, numel, serial ? (1 + numel) : internal::GRAIN_SIZE,
     [&](int64_t begin, int64_t end) {
       auto &acc = buffer[at::get_thread_num()];
-      sub_iter.serial_for_each([&acc, &reduce](int ntensors, char** data, const int64_t* strides, int64_t size) {
+      sub_iter.serial_for_each([&acc, &reduce, &init](int ntensors, char** data, const int64_t* strides, int64_t size) {
         AT_ASSERT(ntensors == 2);
         char *in = data[1];
         int64_t stride = strides[1];
         if (!acc && size > 0) {
-          acc = acc_t {};
+          //acc = acc_t {};
+          acc = init;
         }
         for (int64_t i = 0; i < size; ++i) {
           acc = reduce(*acc, *(data_t*)in);
@@ -101,7 +104,7 @@ void binary_kernel_reduce(TensorIterator& iter, rf_t const &reduce, cf_t const &
         }
       }, {begin, end});
     });
-    acc_t acc;
+    acc_t acc = init;
     for (int i = 0; i < max_threads; ++i) {
       if (buffer[i]) {
         acc = combine(acc, *buffer[i]);
index 4f3f2e1..983e894 100644 (file)
@@ -22,6 +22,17 @@ static void sum_kernel_impl(TensorIterator& iter) {
   });
 }
 
+static void mean_kernel_impl(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES(iter.type(), "mean", [&] {
+    scalar_t factor = scalar_t(iter.num_output_elements()) / iter.numel();
+    binary_kernel_reduce(
+      iter,
+      [=](scalar_t a, scalar_t b) -> scalar_t { return a + b; },
+      [=](scalar_t a, scalar_t b) -> scalar_t { return a + b; },
+      [factor](scalar_t a) -> scalar_t { return a*factor; }, scalar_t(0));
+  });
+}
+
 struct WelfordData {
   double mean;
   double m2;
@@ -63,7 +74,8 @@ static void std_kernel_impl(TensorIterator &iter, bool unbiased) {
       [unbiased](WelfordData acc) -> scalar_t {
         int64_t divisor = unbiased ? (acc.n - 1) : acc.n;
         return (divisor > 0) ? std::sqrt(acc.m2 / divisor) : NAN;
-      }
+      },
+      WelfordData()
     );
   });
 }
@@ -253,6 +265,6 @@ REGISTER_DISPATCH(sum_stub, &sum_kernel_impl);
 REGISTER_DISPATCH(std_stub, &std_kernel_impl);
 REGISTER_DISPATCH(prod_stub, &prod_kernel_impl);
 REGISTER_DISPATCH(norm_kernel, &norm_kernel_impl);
+REGISTER_DISPATCH(mean_stub, &mean_kernel_impl);
 
 }}  // namespace at::native
-
index 7635341..989fd26 100644 (file)
@@ -197,7 +197,8 @@ __device__ Array<type_t, vt> load_memory(const type_t* in, int begin, int end, i
   return load_memory<vt>(in, begin, end, stride, [](int idx) { return idx; });
 }
 
-template <typename scalar_t, typename func_t, typename out_scalar_t=scalar_t>
+template <typename scalar_t, typename func_t, typename pre_func_t,
+          typename post_func_t, typename out_scalar_t=scalar_t>
 struct ReduceOp {
   using traits = binary_function_traits<func_t>;
   using arg_t = typename traits::arg2_t;
@@ -208,6 +209,8 @@ struct ReduceOp {
   static constexpr int vt0 = 4;
 
   func_t op;
+  pre_func_t pre_op;
+  post_func_t post_op;
   arg_t ident;
   ReduceConfig config;
   InputCalculator input_calc;
@@ -219,8 +222,11 @@ struct ReduceOp {
   bool accumulate;
 
   ReduceOp(func_t op, ReduceConfig config, InputCalculator input_calc, OutputCalculator output_calc,
-           const void* src, void* dst, void* buffer, int* semaphores)
+           const void* src, void* dst, void* buffer, int* semaphores, pre_func_t pre_op,
+           post_func_t post_op)
     : op(op)
+    , pre_op(pre_op)
+    , post_op(post_op)
     , config(config)
     , input_calc(input_calc)
     , output_calc(output_calc)
@@ -252,6 +258,7 @@ struct ReduceOp {
     if (config.should_global_reduce()) {
       value = global_reduce(value, out);
     } else if (config.should_store(output_idx)) {
+      value = post_op(value);
       if (accumulate) {
         value = op(*out, value);
       }
@@ -278,7 +285,7 @@ struct ReduceOp {
 
     arg_t value;
     strided_iterate<vt0>([&](int i, int idx) {
-      value = i == 0 ? (arg_t)values[0] : op(value, values[i]);
+      value = i == 0 ? pre_op(values[0]) : op(value, pre_op(values[i]));
     }, offset, config.num_inputs, config.step_input);
 
     return value;
@@ -371,6 +378,7 @@ struct ReduceOp {
       if (config.should_warp_reduce()) {
         value = warp_reduce(value);
       }
+      value = post_op(value);
       if (should_store) {
         if (accumulate) {
           value = op(*out, value);
@@ -393,14 +401,17 @@ static void launch_reduce_kernel(const ReduceConfig& config, const R& reduction)
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
-template <typename scalar_t, typename out_scalar_t, typename func_t, typename ident_t=double>
-inline void gpu_reduce_kernel(TensorIterator& iter, const func_t& op, ident_t ident=0) {
+template <typename scalar_t, typename out_scalar_t, typename func_t, typename pre_func_t,
+          typename post_func_t, typename ident_t=double>
+inline void gpu_reduce_kernel(TensorIterator& iter, const pre_func_t &pre_op,
+                              const post_func_t &post_op, const func_t& op,
+                              ident_t ident=0) {
   ASSERT_HOST_DEVICE_LAMBDA(func_t);
   AT_ASSERT(iter.numel() > 0 && iter.ntensors() == 2);
 
   if (!iter.can_use_32bit_indexing()) {
     for (auto& sub_iter : iter.with_32bit_indexing()) {
-      gpu_reduce_kernel<scalar_t, out_scalar_t>(sub_iter, op);
+      gpu_reduce_kernel<scalar_t, out_scalar_t>(sub_iter, pre_op, post_op, op);
     }
     return;
   }
@@ -465,7 +476,7 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const func_t& op, ident_t id
     auto stream = at::cuda::getCurrentCUDAStream();
     AT_CUDA_CHECK(cudaMemsetAsync(semaphores.get(), 0, config.semaphore_size(), stream));
   }
-  auto reduce = ReduceOp<scalar_t, func_t, out_scalar_t>(
+  auto reduce = ReduceOp<scalar_t, func_t, pre_func_t, post_func_t, out_scalar_t>(
       op,
       config,
       input_calc,
@@ -473,7 +484,9 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const func_t& op, ident_t id
       in_data,
       out_data,
       buffer.get(),
-      (int*)semaphores.get());
+      (int*)semaphores.get(),
+      pre_op,
+      post_op);
   reduce.ident = ident;
   reduce.accumulate = iter.should_accumulate();
 
index 71e21db..283fc57 100644 (file)
 
 namespace at { namespace native {
 
+namespace {
+
+template <typename scalar_t>
+struct SimpleCopy {
+  __device__ __forceinline__ scalar_t operator() (const scalar_t a) const {
+    return a;
+  }
+};
+
+} // namespace
+
 template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
 void sum_kernel_impl(TensorIterator& iter) {
-  gpu_reduce_kernel<scalar_t, out_t>(iter, []GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
+  gpu_reduce_kernel<scalar_t, out_t>(iter, SimpleCopy<acc_t>(), SimpleCopy<acc_t>(),
+                                     []GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
     return a + b;
   });
 }
@@ -25,7 +37,8 @@ void sum_kernel_impl<int16_t, int16_t>(TensorIterator& iter) {
   // compiler segfaults:
   // https://bugs.llvm.org/show_bug.cgi?id=39602
   // To work around it, use int32 as the accumulate type.
-  gpu_reduce_kernel<int16_t, int16_t>(iter, []GPU_LAMBDA(int32_t a, int32_t b) -> int32_t {
+  gpu_reduce_kernel<int16_t, int16_t>(iter, SimpleCopy<int32_t>(), SimpleCopy<int32_t>(),
+                                      []GPU_LAMBDA(int32_t a, int32_t b) -> int32_t {
     return a + b;
   });
 }
@@ -33,11 +46,34 @@ void sum_kernel_impl<int16_t, int16_t>(TensorIterator& iter) {
 
 template <typename scalar_t, typename acc_t=scalar_t>
 void prod_kernel_impl(TensorIterator& iter) {
-  gpu_reduce_kernel<scalar_t, scalar_t>(iter, []GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
+  gpu_reduce_kernel<scalar_t, scalar_t>(iter, SimpleCopy<acc_t>(), SimpleCopy<acc_t>(),
+                                        []GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
     return a * b;
   }, 1);
 }
 
+template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
+void mean_kernel_impl(TensorIterator& iter) {
+  float factor = float(iter.num_output_elements()) / iter.numel();
+  gpu_reduce_kernel<scalar_t, out_t>(iter, SimpleCopy<acc_t>(), 
+      [factor]GPU_LAMBDA(acc_t a) -> acc_t { return a*factor; },
+      []GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { return a + b; });
+}
+
+#ifdef __HIPCC__
+template <>
+void mean_kernel_impl<int16_t, int16_t, int16_t>(TensorIterator& iter) {
+  // There is a Register Coalescing bug in LLVM causing the hcc
+  // compiler segfaults:
+  // https://bugs.llvm.org/show_bug.cgi?id=39602
+  // To work around it, use int32 as the accumulate type.
+  float factor = float(iter.num_output_elements()) / iter.numel();
+  gpu_reduce_kernel<int16_t, int16_t>(iter, SimpleCopy<int32_t>(),
+      [factor]GPU_LAMBDA(int32_t a) -> int32_t { return a*factor; },
+      []GPU_LAMBDA(int32_t a, int32_t b) -> int32_t { return a + b; });
+}
+#endif // __HIPCC__
+
 static void sum_kernel_cuda(TensorIterator& iter) {
   if (iter.type().scalarType() == kHalf) {
     return sum_kernel_impl<at::Half, float>(iter);
@@ -59,7 +95,20 @@ static void prod_kernel_cuda(TensorIterator& iter) {
   });
 }
 
+static void mean_kernel_cuda(TensorIterator& iter) {
+  if (iter.type().scalarType() == kHalf) {
+    return mean_kernel_impl<at::Half, float>(iter);
+  } else if (iter.type(1).scalarType() == kHalf && iter.type().scalarType() == kFloat) {
+    // type promotion that does cast and reduction in a single kernel
+    return mean_kernel_impl<at::Half, float, float>(iter);
+  }
+  AT_DISPATCH_ALL_TYPES(iter.type(), "mean", [&]() {
+    mean_kernel_impl<scalar_t>(iter);
+  });
+}
+
 REGISTER_DISPATCH(sum_stub, &sum_kernel_cuda);
 REGISTER_DISPATCH(prod_stub, &prod_kernel_cuda);
+REGISTER_DISPATCH(mean_stub, &mean_kernel_cuda);
 
 }} // namespace at::native
index d00d4e2..4d860cb 100644 (file)
@@ -1547,6 +1547,9 @@ class TestCuda(TestCase):
         self.assertEqual(x.sum(), 65504)
         self.assertEqual(x.sum(dtype=torch.float32), 65504)
 
+        x = torch.ones(65536, device='cuda', dtype=torch.float16)
+        self.assertEqual(x.sum(dtype=torch.float32), 65536)
+
         a = torch.zeros(1203611).bernoulli_(0.0005)
         x = a.to(device='cuda', dtype=torch.float16)
         self.assertEqual(x.sum().item(), a.sum().item())
@@ -1555,6 +1558,14 @@ class TestCuda(TestCase):
         x = a.to(device='cuda', dtype=torch.float16)
         self.assertEqual(x.sum((0, 2)).float().cpu(), a.sum((0, 2)))
 
+    @skipIfRocm
+    def test_mean_fp16(self):
+        x = torch.ones(65536, device='cuda', dtype=torch.float16)
+        self.assertEqual(x.mean(), 1)
+
+        x = torch.ones(65536, device='cuda', dtype=torch.float16)
+        self.assertEqual(x.mean(dtype=torch.float32), 1)
+
     @staticmethod
     def _select_broadcastable_dims(dims_full=None):
         return _TestTorchMixin._select_broadcastable_dims(dims_full)
index 42c89da..99fe5f4 100644 (file)
   self: grad.clone().masked_fill_(self <= other, 0)
   other: grad.clone().masked_fill_(self > other, 0)
 
+- name: mean(Tensor self)
+  self: grad.expand(self.sizes()) / self.numel()
+
+- name: mean(Tensor self, ScalarType dtype)
+  self: grad.expand(self.sizes()).to(self.type().scalarType()) / self.numel()
+
 - name: mean(Tensor self, IntList dim, bool keepdim)
   self: sum_backward(grad, self.sizes(), dim, keepdim) / _safe_size(self.sizes(), dim)
 
-- name: mean(Tensor self)
-  self: grad.expand(self.sizes()) / self.numel()
+- name: mean(Tensor self, IntList dim, ScalarType dtype)
+  self: sum_backward(grad, self.sizes(), dim, false).to(self.type().scalarType()) / _safe_size(self.sizes(), dim)
+
+- name: mean(Tensor self, IntList dim, bool keepdim, ScalarType dtype)
+  self: sum_backward(grad, self.sizes(), dim, keepdim).to(self.type().scalarType()) / _safe_size(self.sizes(), dim)
 
 - name: median(Tensor self)
   self: select_equals_backward(grad, self, result)