multi-dim standard deviation for CUDA. (#14990)
authorBrennan Vincent <btv@fb.com>
Thu, 20 Dec 2018 16:53:44 +0000 (08:53 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 20 Dec 2018 16:56:32 +0000 (08:56 -0800)
Summary:
This is the CUDA version of #14535 .
It refactors Reduce.cuh to allow more general classes of reductions to be performed -- we no longer assume that the temporary data returned during reduction is just one scalar, and instead allow an arbitrary accumulate type.
We also allow 64-bit indexing when necessary, since in general we will no longer be able to accumulate directly in the output. (In the cases when we can, we continue to split the tensors until they can be addressed with 32-bits, as before).
As an initial use-case, we implement `std` in multiple dimensions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14990

Differential Revision: D13405097

Pulled By: umanwizard

fbshipit-source-id: a56c24dc2fd5326d417632089bd3f5c4f9f0d2cb

12 files changed:
aten/src/ATen/cuda/detail/OffsetCalculator.cuh
aten/src/ATen/detail/FunctionTraits.h
aten/src/ATen/native/ReduceOps.cpp
aten/src/ATen/native/SharedReduceOps.h [new file with mode: 0644]
aten/src/ATen/native/cpu/Reduce.h
aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
aten/src/ATen/native/cuda/DeviceSqrt.cuh [new file with mode: 0644]
aten/src/ATen/native/cuda/Normalization.cuh
aten/src/ATen/native/cuda/Reduce.cuh
aten/src/ATen/native/cuda/ReduceOpsKernel.cu
test/test_torch.py
torch/_torch_docs.py

index e8087a7..b750cf3 100644 (file)
@@ -9,20 +9,20 @@
 /// OffsetCalculator calculates the offset in bytes of a linear index for NARGS
 /// operands that share the same shape, but may have different strides.
 
-template <int NARGS>
+template <int NARGS, typename index_t = uint32_t>
 struct OffsetCalculator {
   static constexpr int MAX_DIMS = 25;
 
   // The offset for each argument (in bytes). Wrapper around fixed-size array.
-  using offset_type = at::cuda::Array<uint32_t, NARGS>;
+  using offset_type = at::cuda::Array<index_t, NARGS>;
 
   OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides) : dims(dims) {
     AT_CHECK(dims <= MAX_DIMS, "tensor has too many (>25) dims");
     for (int i = 0; i < MAX_DIMS; ++i) {
       if (i < dims) {
-        sizes_[i] = IntDivider<uint32_t>(sizes[i]);
+        sizes_[i] = IntDivider<index_t>(sizes[i]);
       } else {
-        sizes_[i] = IntDivider<uint32_t>(1);
+        sizes_[i] = IntDivider<index_t>(1);
       }
       for (int arg = 0; arg < NARGS; arg++) {
         strides_[i][arg] =  i < dims ? strides[arg][i] : 0;
@@ -30,7 +30,7 @@ struct OffsetCalculator {
     }
   }
 
-  C10_HOST_DEVICE offset_type get(uint32_t linear_idx) const {
+  C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
     offset_type offsets;
     #pragma unroll
     for (int arg = 0; arg < NARGS; arg++) {
@@ -55,6 +55,6 @@ struct OffsetCalculator {
   }
 
   int dims;
-  IntDivider<uint32_t> sizes_[MAX_DIMS];
-  uint32_t strides_[MAX_DIMS][NARGS];
+  IntDivider<index_t> sizes_[MAX_DIMS];
+  index_t strides_[MAX_DIMS][NARGS];
 };
index 077be07..a8f84e6 100644 (file)
@@ -4,13 +4,37 @@
 
 // Modified from https://stackoverflow.com/questions/7943525/is-it-possible-to-figure-out-the-parameter-type-and-return-type-of-a-lambda
 
-// For generic types, directly use the result of the signature of its 'operator()'
+// Fallback, anything with an operator()
 template <typename T>
 struct function_traits : public function_traits<decltype(&T::operator())> {
 };
 
+// Pointers to class members that are themselves functors.
+// For example, in the following code:
+// template <typename func_t>
+// struct S {
+//     func_t f;
+// };
+// template <typename func_t>
+// S<func_t> make_s(func_t f) {
+//     return S<func_t> { .f = f };
+// }
+//
+// auto s = make_s([] (int, float) -> double { /* ... */ });
+//
+// function_traits<decltype(&s::f)> traits;
+template <typename ClassType, typename T>
+struct function_traits<T ClassType::*> : public function_traits<T> {
+};
+
+// Const class member functions
 template <typename ClassType, typename ReturnType, typename... Args>
-struct function_traits<ReturnType(ClassType::*)(Args...) const> {
+struct function_traits<ReturnType(ClassType::*)(Args...) const> : public function_traits<ReturnType(Args...)> {
+};
+
+// Free functions
+template <typename ReturnType, typename... Args>
+struct function_traits<ReturnType(Args...)> {
   // arity is the number of arguments.
   enum { arity = sizeof...(Args) };
 
index 4282d37..2bf420f 100644 (file)
@@ -524,19 +524,13 @@ Tensor &std_out(Tensor &result, const Tensor &self, IntList dim, bool unbiased,
   AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
            "std only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
   AT_CHECK(at::isFloatingType(self.type().scalarType()), "std only supports floating-point dtypes");
-  if (self.type().backend() != Backend::CPU) {
-    // TODO(btv): implement multi-dim `std` and `var` on CUDA.
-    AT_CHECK(dim.size() == 1, "`std` across arbitrarily many dimensions is not yet supported for CUDA.")
-    int64_t one_dim = maybe_wrap_dim(dim[0], self.dim());
-    if (_dimreduce_return_trivial(result, self, std::numeric_limits<double>::quiet_NaN(), one_dim, keepdim)) {
-      return result;
-    } else {
-      return at::legacy::th::_th_std_out(result, self, one_dim, unbiased, keepdim);
-    }
-  }
   ScalarType dtype = get_dtype(result, self, {}, true);
   auto iter = make_reduction("std", result, self, dim, keepdim, dtype);
-  std_stub(iter->device_type(), *iter, unbiased);
+  if (iter->numel() == 0) {
+    result.fill_(NAN);
+  } else {
+    std_stub(iter->device_type(), *iter, unbiased);
+  }
   return result;
 }
 
diff --git a/aten/src/ATen/native/SharedReduceOps.h b/aten/src/ATen/native/SharedReduceOps.h
new file mode 100644 (file)
index 0000000..1c75eaa
--- /dev/null
@@ -0,0 +1,104 @@
+#pragma once
+// Please note that this file is
+// used across both CPU and GPU.
+
+#include <c10/macros/Macros.h>
+#if defined(__CUDACC__)
+#include <THC/THCDeviceUtils.cuh>
+#include <ATen/native/cuda/DeviceSqrt.cuh>
+#elif defined(__HIPCC__)
+#include <THH/THHDeviceUtils.cuh>
+#include <ATen/native/hip/DeviceSqrt.cuh>
+#else
+#include <cmath>
+#define device_sqrt std::sqrt
+#endif
+
+namespace at { namespace native {
+
+template <typename scalar_t>
+struct WelfordData {
+  scalar_t mean;
+  scalar_t m2;
+  int64_t n;
+  C10_HOST_DEVICE WelfordData() : mean(0), m2(0), n(0)  {}
+  C10_DEVICE WelfordData(scalar_t mean, scalar_t m2, int64_t n) : mean(mean), m2(m2), n(n) {}
+};
+
+
+template <typename scalar_t, typename acc_scalar_t>
+struct WelfordOps {
+  bool unbiased;
+ public:
+  using acc_t = WelfordData<acc_scalar_t>;
+  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data) const {
+    acc_scalar_t delta = data - acc.mean;
+    acc_scalar_t new_mean = acc.mean + delta / (acc.n + 1);
+    acc_scalar_t new_delta = data - new_mean;
+    return {
+      new_mean,
+      acc.m2 + delta * new_delta,
+      acc.n + 1
+    };
+  }
+  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
+    if (a.n == 0) {
+      return b;
+    }
+    if (b.n == 0) {
+      return a;
+    }
+    acc_scalar_t delta = b.mean - a.mean;
+    int64_t new_count = a.n + b.n;
+    acc_scalar_t nb_over_n = (scalar_t)b.n / new_count;
+    return {
+      a.mean + delta * nb_over_n,
+      a.m2 + b.m2 + delta * delta * a.n * nb_over_n,
+      new_count
+    };
+  }
+  inline C10_DEVICE scalar_t project(acc_t acc) const {
+    int64_t divisor = unbiased ? (acc.n - 1) : acc.n;
+    return (divisor > 0) ? (scalar_t)device_sqrt(acc.m2 / divisor) : (scalar_t)NAN;
+  }
+#if defined(__CUDACC__) || defined(__HIPCC__)
+  inline __device__ acc_t warp_shfl_down(acc_t acc, int offset) const {
+    return {
+      WARP_SHFL_DOWN(acc.mean, offset)
+      , WARP_SHFL_DOWN(acc.m2, offset)
+      , WARP_SHFL_DOWN(acc.n, offset)
+    };
+  }
+#endif
+  WelfordOps(bool unbiased) : unbiased(unbiased) {
+  }
+};
+
+template <typename acc_t, typename factor_t>
+struct MeanOps {
+  factor_t factor;
+
+  inline C10_DEVICE acc_t reduce(acc_t a, acc_t b) const {
+    return a + b;
+  }
+
+  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
+    return reduce(a, b);
+  }
+
+  inline C10_DEVICE acc_t project(acc_t a) const {
+    return a * factor;
+  }
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+  inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
+    return WARP_SHFL_DOWN(data, offset);
+  }
+#endif
+
+  MeanOps(factor_t factor): factor(factor) {
+  }
+};
+
+
+}} // namespace at::native
index 9e22d5e..3cab114 100644 (file)
@@ -34,7 +34,8 @@ struct all_same : c10::guts::conjunction<
 // acc_t is a type that contains all the necessary data
 // to continue reducing.
 //
-// Then:
+// ops_t is such that &ops_t::reduce, &ops_t::combine, and &ops_t::project exist and satisfy
+// the following.
 // reduce: (acc_t, data_t) -> acc_t adds one data point to the accumulated value.
 // combine: (acc_t, acc_t) -> acc_t combines two accumulated values into one.
 // project: acc_t -> data_t finishes the reduction, getting the required output.
@@ -54,11 +55,11 @@ struct all_same : c10::guts::conjunction<
 // If, on the other hand, there is only one, then we split the input into
 // into several pieces, reduce each separately, and then combine them.
 
-template <typename rf_t,
-          typename cf_t,
-          typename pf_t,
-          typename init_t>
-void binary_kernel_reduce(TensorIterator& iter, rf_t const &reduce, cf_t const &combine, pf_t const &project, init_t init) {
+template <typename ops_t, typename init_t>
+void binary_kernel_reduce(TensorIterator& iter, ops_t ops, init_t init) {
+  using rf_t = decltype(&ops_t::reduce);
+  using cf_t = decltype(&ops_t::combine);
+  using pf_t = decltype(&ops_t::project);
   using r_traits = binary_function_traits<rf_t>;
   using c_traits = binary_function_traits<cf_t>;
   using p_traits = unary_function_traits<pf_t>;
@@ -90,7 +91,7 @@ void binary_kernel_reduce(TensorIterator& iter, rf_t const &reduce, cf_t const &
     at::parallel_for(0, numel, serial ? (1 + numel) : internal::GRAIN_SIZE,
     [&](int64_t begin, int64_t end) {
       auto &acc = buffer[at::get_thread_num()];
-      sub_iter.serial_for_each([&acc, &reduce, &init](int ntensors, char** data, const int64_t* strides, int64_t size) {
+      sub_iter.serial_for_each([&acc, &ops, &init](int ntensors, char** data, const int64_t* strides, int64_t size) {
         AT_ASSERT(ntensors == 2);
         char *in = data[1];
         int64_t stride = strides[1];
@@ -99,7 +100,7 @@ void binary_kernel_reduce(TensorIterator& iter, rf_t const &reduce, cf_t const &
           acc = init;
         }
         for (int64_t i = 0; i < size; ++i) {
-          acc = reduce(*acc, *(data_t*)in);
+          acc = ops.reduce(*acc, *(data_t*)in);
           in += stride;
         }
       }, {begin, end});
@@ -107,11 +108,11 @@ void binary_kernel_reduce(TensorIterator& iter, rf_t const &reduce, cf_t const &
     acc_t acc = init;
     for (int i = 0; i < max_threads; ++i) {
       if (buffer[i]) {
-        acc = combine(acc, *buffer[i]);
+        acc = ops.combine(acc, *buffer[i]);
       }
     }
     char *out = (char *)sub_iter.data_ptr(0);
-    *(data_t*)out = project(acc);
+    *(data_t*)out = ops.project(acc);
   });
 }
 
index 0881912..a8238a8 100644 (file)
@@ -6,6 +6,7 @@
 #include <ATen/cpu/vec256/vec256.h>
 #include <ATen/native/ReduceOps.h>
 #include <ATen/native/TensorIterator.h>
+#include <ATen/native/SharedReduceOps.h>
 #include <ATen/native/cpu/Reduce.h>
 #include <c10/util/Optional.h>
 
@@ -27,55 +28,18 @@ static void mean_kernel_impl(TensorIterator& iter) {
     scalar_t factor = scalar_t(iter.num_output_elements()) / iter.numel();
     binary_kernel_reduce(
       iter,
-      [=](scalar_t a, scalar_t b) -> scalar_t { return a + b; },
-      [=](scalar_t a, scalar_t b) -> scalar_t { return a + b; },
-      [factor](scalar_t a) -> scalar_t { return a*factor; }, scalar_t(0));
+      MeanOps<scalar_t, scalar_t> {factor},
+      scalar_t(0)
+    );
   });
 }
 
-struct WelfordData {
-  double mean;
-  double m2;
-  int64_t n;
-  WelfordData() : mean(0), m2(0), n(0)  {}
-  WelfordData(double mean, double m2, int64_t n) : mean(mean), m2(m2), n(n) {}
-};
-
 static void std_kernel_impl(TensorIterator &iter, bool unbiased) {
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.type(), "std", [&] {
     binary_kernel_reduce(
       iter,
-      [](WelfordData acc, scalar_t data) -> WelfordData {
-        double delta = data - acc.mean;
-        double new_mean = acc.mean + delta / (acc.n + 1);
-        double new_delta = data - new_mean;
-        return {
-          new_mean,
-          acc.m2 + delta * new_delta,
-          acc.n + 1
-        };
-      },
-      [](WelfordData a, WelfordData b) -> WelfordData {
-        if (a.n == 0) {
-          return b;
-        }
-        if (b.n == 0) {
-          return a;
-        }
-        double delta = b.mean - a.mean;
-        int64_t new_count = a.n + b.n;
-        double nb_over_n = (double)b.n / new_count;
-        return {
-          a.mean + delta * nb_over_n,
-          a.m2 + b.m2 + delta * delta * a.n * nb_over_n,
-          new_count
-        };
-      },
-      [unbiased](WelfordData acc) -> scalar_t {
-        int64_t divisor = unbiased ? (acc.n - 1) : acc.n;
-        return (divisor > 0) ? std::sqrt(acc.m2 / divisor) : NAN;
-      },
-      WelfordData()
+      WelfordOps<scalar_t, double> { unbiased },
+      WelfordData<double>()
     );
   });
 }
diff --git a/aten/src/ATen/native/cuda/DeviceSqrt.cuh b/aten/src/ATen/native/cuda/DeviceSqrt.cuh
new file mode 100644 (file)
index 0000000..29711a0
--- /dev/null
@@ -0,0 +1,25 @@
+#pragma once
+
+namespace at { namespace native {
+#if defined(__HIP_PLATFORM_HCC__)
+// take these out when ROCm implements std:: math functions
+#include <math.h>
+template <typename scalar_t>
+static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val);
+
+template <>
+__forceinline__ __device__ float device_sqrt(float val) {
+  return ::sqrtf(val);
+}
+
+template <>
+__forceinline__ __device__ double device_sqrt(double val) {
+  return ::sqrt(val);
+}
+#else
+template<typename scalar_t>
+__forceinline__ __device__ double device_sqrt(scalar_t val) {
+  return std::sqrt(val);
+}
+#endif
+}}
index c4d10f7..e186ef3 100644 (file)
@@ -6,6 +6,7 @@
 #include <ATen/AccumulateType.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <ATen/native/cuda/DeviceSqrt.cuh>
 
 #include <ATen/native/cuda/Normalization.cuh>
 
@@ -13,29 +14,8 @@ namespace at { namespace native {
 
 #if defined(__HIP_PLATFORM_HCC__)
 constexpr int WARP_SIZE = 64;
-
-// take these out when ROCm implements std:: math functions
-#include <math.h>
-template <typename scalar_t>
-static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val);
-
-template <>
-__forceinline__ __device__ float device_sqrt(float val) {
-  return ::sqrtf(val);
-}
-
-template <>
-__forceinline__ __device__ double device_sqrt(double val) {
-  return ::sqrt(val);
-}
-
 #else
 constexpr int WARP_SIZE = 32;
-
-template<typename scalar_t>
-__forceinline__ __device__ double device_sqrt(scalar_t val) {
-  return std::sqrt(val);
-}
 #endif
 
 // The maximum number of threads in a block
index 989fd26..6e870ad 100644 (file)
@@ -1,5 +1,6 @@
 #pragma once
 
+#include <assert.h>
 #include <ATen/ATen.h>
 #include <ATen/cuda/Array.h>
 #include <ATen/cuda/CUDAContext.h>
@@ -9,7 +10,11 @@
 #include <THC/THCGeneral.hpp>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cuda/Loops.cuh>
+#include <functional>
 #include <iosfwd>
+#include <tuple>
+#include <type_traits>
+#include <utility>
 
 namespace at { namespace native {
 
@@ -115,11 +120,11 @@ struct ReduceConfig {
     return element_size_bytes * NUM_THREADS;
   }
 
-  int global_memory_size() const {
+  int64_t global_memory_size() const {
     if (!should_global_reduce()) {
       return 0;
     }
-    int size = element_size_bytes * num_outputs * ctas_per_output;
+    auto size = (int64_t)element_size_bytes * num_outputs * ctas_per_output;
     if (!should_warp_reduce()) {
       size *= block().x;
     }
@@ -146,7 +151,8 @@ __global__ void reduce_kernel(R reduction) {
   reduction.run();
 }
 
-static OffsetCalculator<2> make_output_calculator(const TensorIterator& iter) {
+template <typename index_t>
+static OffsetCalculator<2, index_t> make_output_calculator(const TensorIterator& iter) {
   int num_reduce_dims = iter.num_reduce_dims();
   int num_output_dims = iter.ndim() - num_reduce_dims;
   std::array<const int64_t*, 2> strides = {
@@ -154,28 +160,29 @@ static OffsetCalculator<2> make_output_calculator(const TensorIterator& iter) {
     iter.strides(1).data() + num_reduce_dims,
   };
   auto shape = iter.shape().data() + num_reduce_dims;
-  return OffsetCalculator<2>(num_output_dims, shape, strides.data());
+  return OffsetCalculator<2, index_t>(num_output_dims, shape, strides.data());
 }
 
-static OffsetCalculator<1> make_input_calculator(const TensorIterator& iter) {
+template <typename index_t>
+static OffsetCalculator<1, index_t> make_input_calculator(const TensorIterator& iter) {
   int num_reduce_dims = iter.num_reduce_dims();
   std::array<const int64_t*, 1> strides = {
     iter.strides(1).data(),
   };
-  return OffsetCalculator<1>(num_reduce_dims, iter.shape().data(), strides.data());
+  return OffsetCalculator<1, index_t>(num_reduce_dims, iter.shape().data(), strides.data());
 }
 
-template <int vt, typename func_t>
-__device__ void strided_iterate(func_t f, int begin, int end, int stride) {
+template <int vt, typename index_t, typename func_t>
+__device__ void strided_iterate(func_t f, index_t begin, index_t end, index_t stride) {
   if (begin + (vt - 1) * stride < end) {
     #pragma unroll
-    for (int i = 0; i < vt; i++) {
+    for (index_t i = 0; i < vt; i++) {
       f(i, begin + i * stride);
     }
   } else {
     #pragma unroll
-    for (int i = 0; i < vt; i++) {
-      int idx = begin + i * stride;
+    for (index_t i = 0; i < vt; i++) {
+      index_t idx = begin + i * stride;
       if (idx < end) {
         f(i, idx);
       }
@@ -183,34 +190,56 @@ __device__ void strided_iterate(func_t f, int begin, int end, int stride) {
   }
 }
 
-template <int vt, typename type_t, typename foo_t>
-__device__ Array<type_t, vt> load_memory(const type_t* in, int begin, int end, int stride, foo_t foo) {
+template <int vt, typename index_t, typename type_t, typename foo_t>
+__device__ Array<type_t, vt> load_memory(const type_t* in, index_t begin, index_t end, index_t stride, foo_t foo) {
   Array<type_t, vt> res;
-  strided_iterate<vt>([&](int i, int idx) {
+  strided_iterate<vt>([&](index_t i, index_t idx) {
     res[i] = in[foo(idx)];
   }, begin, end, stride);
   return res;
 }
 
-template <int vt, typename type_t>
-__device__ Array<type_t, vt> load_memory(const type_t* in, int begin, int end, int stride) {
-  return load_memory<vt>(in, begin, end, stride, [](int idx) { return idx; });
+template <int vt, typename index_t, typename type_t>
+__device__ Array<type_t, vt> load_memory(const type_t* in, index_t begin, index_t end, index_t stride) {
+  return load_memory<vt, index_t>(in, begin, end, stride, [](index_t idx) { return idx; });
 }
 
-template <typename scalar_t, typename func_t, typename pre_func_t,
-          typename post_func_t, typename out_scalar_t=scalar_t>
+template <typename out_scalar_t, typename func_t>
+struct func_wrapper_t {
+  using arg_t = typename binary_function_traits<func_t>::arg2_t;
+  func_t reduce;
+  func_t combine;
+  static inline __device__ out_scalar_t project(arg_t arg) {
+    return (out_scalar_t) arg;
+  }
+  static inline __device__ arg_t warp_shfl_down(arg_t arg, int offset) {
+    return WARP_SHFL_DOWN(arg, offset);
+  }
+
+  func_wrapper_t(const func_t& op) : reduce(op), combine(op) {
+  }
+};
+
+template <typename scalar_t, typename func_t>
+func_wrapper_t<scalar_t, func_t> func_wrapper(const func_t& op) {
+  using arg_t = typename binary_function_traits<func_t>::arg2_t;
+  return func_wrapper_t<scalar_t, func_t> { op };
+}
+
+template <typename scalar_t, typename ops_t, typename index_t, typename out_scalar_t=scalar_t>
 struct ReduceOp {
-  using traits = binary_function_traits<func_t>;
-  using arg_t = typename traits::arg2_t;
+  using traits = binary_function_traits<decltype(&ops_t::reduce)>;
+  using arg_t = typename std::remove_const<typename std::remove_reference<typename traits::arg1_t>::type>::type;
 
-  using InputCalculator = OffsetCalculator<1>;
-  using OutputCalculator = OffsetCalculator<2>;
+  using InputCalculator = OffsetCalculator<1, index_t>;
+  using OutputCalculator = OffsetCalculator<2, index_t>;
 
   static constexpr int vt0 = 4;
+  static constexpr bool can_accumulate_in_output =
+    std::is_convertible<arg_t, out_scalar_t>::value;
+
 
-  func_t op;
-  pre_func_t pre_op;
-  post_func_t post_op;
+  ops_t ops;
   arg_t ident;
   ReduceConfig config;
   InputCalculator input_calc;
@@ -221,24 +250,22 @@ struct ReduceOp {
   int* semaphores;
   bool accumulate;
 
-  ReduceOp(func_t op, ReduceConfig config, InputCalculator input_calc, OutputCalculator output_calc,
-           const void* src, void* dst, void* buffer, int* semaphores, pre_func_t pre_op,
-           post_func_t post_op)
-    : op(op)
-    , pre_op(pre_op)
-    , post_op(post_op)
+  ReduceOp(ops_t ops, ReduceConfig config, InputCalculator input_calc, OutputCalculator output_calc,
+           const void* src, void* dst, void* buffer, int* semaphores, arg_t ident)
+    : ops(ops)
     , config(config)
     , input_calc(input_calc)
     , output_calc(output_calc)
     , src(src)
     , dst(dst)
     , buffer(buffer)
-    , semaphores(semaphores) {
+    , semaphores(semaphores)
+    , ident(ident) {
   }
 
   C10_DEVICE void run() const {
-    int output_idx = config.output_idx();
-    int input_idx = config.input_idx();
+    index_t output_idx = config.output_idx();
+    index_t input_idx = config.input_idx();
     auto base_offsets = output_calc.get(output_idx);
 
     arg_t value = ident;
@@ -258,34 +285,33 @@ struct ReduceOp {
     if (config.should_global_reduce()) {
       value = global_reduce(value, out);
     } else if (config.should_store(output_idx)) {
-      value = post_op(value);
       if (accumulate) {
-        value = op(*out, value);
+        value = accumulate_in_output<can_accumulate_in_output>(out, value);
       }
-      *out = value;
+      *out = ops.project(value);
     }
   }
 
-  C10_DEVICE Array<scalar_t, vt0> load_inputs(const scalar_t* data, int offset) const {
-    int end = config.num_inputs;
-    int stride = input_calc.strides_[0][0] / sizeof(scalar_t);
+  C10_DEVICE Array<scalar_t, vt0> load_inputs(const scalar_t* data, index_t offset) const {
+    index_t end = config.num_inputs;
+    index_t stride = input_calc.strides_[0][0] / sizeof(scalar_t);
     if (input_calc.dims == 1) {
-      return load_memory<vt0>(data, offset, end, config.step_input, [&](int idx) {
+      return load_memory<vt0, index_t>(data, offset, end, config.step_input, [&](index_t idx) {
         return idx * stride;
       });
     } else {
-      return load_memory<vt0>(data, offset, end, config.step_input, [&](int idx) {
+      return load_memory<vt0, index_t>(data, offset, end, config.step_input, [&](index_t idx) {
         return input_calc.get(idx)[0] / sizeof(scalar_t);
       });
     }
   }
 
-  C10_DEVICE arg_t thread_reduce_once(const scalar_t* data, int offset) const {
+  C10_DEVICE arg_t thread_reduce_once(const scalar_t* data, index_t offset) const {
     auto values = load_inputs(data, offset);
 
-    arg_t value;
-    strided_iterate<vt0>([&](int i, int idx) {
-      value = i == 0 ? pre_op(values[0]) : op(value, pre_op(values[i]));
+    arg_t value = ident;
+    strided_iterate<vt0, index_t>([&](index_t i, index_t idx) {
+      value = ops.reduce(value, values[i]);
     }, offset, config.num_inputs, config.step_input);
 
     return value;
@@ -293,10 +319,10 @@ struct ReduceOp {
 
   C10_DEVICE arg_t thread_reduce(const scalar_t* data) const {
     arg_t value = ident;
-    int idx = config.input_idx();
+    index_t idx = config.input_idx();
     while (idx < config.num_inputs) {
       arg_t next = thread_reduce_once(data, idx);
-      value = op(value, next);
+      value = ops.combine(value, next);
       idx += config.step_input * vt0;
     }
     return value;
@@ -304,8 +330,8 @@ struct ReduceOp {
 
   C10_DEVICE arg_t warp_reduce(arg_t value) const {
     for (int offset = 1; offset < warpSize; offset <<= 1) {
-      arg_t other = WARP_SHFL_DOWN(value, offset);
-      value = op(value, other);
+      arg_t other = ops.warp_shfl_down(value, offset);
+      value = ops.combine(value, other);
     }
     return value;
   }
@@ -319,7 +345,7 @@ struct ReduceOp {
       __syncthreads();
       if (threadIdx.y < offset && threadIdx.y + offset < num_warps) {
         arg_t other = shared[config.shared_memory_offset(offset)];
-        value = op(value, other);
+        value = ops.combine(value, other);
         shared[config.shared_memory_offset(0)] = value;
       }
     }
@@ -341,13 +367,33 @@ struct ReduceOp {
 
     return is_last_block_done;
   }
+  
+  template <bool can_acc>
+  C10_DEVICE arg_t accumulate_in_output(
+    out_scalar_t* out, arg_t value,
+    typename std::enable_if<can_acc>::type* = nullptr
+  ) const {
+    return ops.reduce(*out, value);
+  }
+
+  // This function should never be called --
+  // it's the version of `accumulate_in_output`
+  // when accumulation in the output is not possible.
+  template <bool can_acc>
+  C10_DEVICE arg_t accumulate_in_output(
+    out_scalar_t*, arg_t,
+    typename std::enable_if<!can_acc>::type* = nullptr
+  ) const {
+    assert(false); // can't use AT_ASSERT in Cuda.
+    return arg_t {};
+  }
 
   C10_DEVICE arg_t global_reduce(arg_t value, out_scalar_t* out) const {
     arg_t* reduce_buffer = (arg_t*)buffer;
 
     bool should_store = config.should_store(config.output_idx());
     if (should_store) {
-      int offset = config.staging_memory_offset(blockIdx.y);
+      index_t offset = config.staging_memory_offset(blockIdx.y);
       reduce_buffer[offset] = value;
     }
 
@@ -356,34 +402,33 @@ struct ReduceOp {
     bool is_last_block_done = mark_block_finished();
 
     if (is_last_block_done) {
-      value = 0;
+      value = arg_t {};
       if (config.should_warp_reduce()) {
-        int input_offset = threadIdx.x + threadIdx.y * blockDim.x;
-        int step = blockDim.x * blockDim.y;
+        index_t input_offset = threadIdx.x + threadIdx.y * blockDim.x;
+        index_t step = blockDim.x * blockDim.y;
         for (; input_offset < config.ctas_per_output; input_offset += step) {
-          int idx = config.staging_memory_offset(input_offset);
+          index_t idx = config.staging_memory_offset(input_offset);
           arg_t next = reduce_buffer[idx];
-          value = op(value, next);
+          value = ops.combine(value, next);
         }
       } else {
-        int input_offset = threadIdx.y;
-        int step = blockDim.y;
+        index_t input_offset = threadIdx.y;
+        index_t step = blockDim.y;
         for (; input_offset < config.ctas_per_output; input_offset += step) {
-          int idx = config.staging_memory_offset(input_offset);
+          index_t idx = config.staging_memory_offset(input_offset);
           arg_t next = reduce_buffer[idx];
-          value = op(value, next);
+          value = ops.combine(value, next);
         }
       }
       value = block_reduce(value);
       if (config.should_warp_reduce()) {
         value = warp_reduce(value);
       }
-      value = post_op(value);
       if (should_store) {
         if (accumulate) {
-          value = op(*out, value);
+          value = accumulate_in_output<can_accumulate_in_output>(out, value);
         }
-        *out = value;
+        *out = ops.project(value);
       }
     }
 
@@ -401,17 +446,19 @@ static void launch_reduce_kernel(const ReduceConfig& config, const R& reduction)
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
-template <typename scalar_t, typename out_scalar_t, typename func_t, typename pre_func_t,
-          typename post_func_t, typename ident_t=double>
-inline void gpu_reduce_kernel(TensorIterator& iter, const pre_func_t &pre_op,
-                              const post_func_t &post_op, const func_t& op,
-                              ident_t ident=0) {
-  ASSERT_HOST_DEVICE_LAMBDA(func_t);
+template <typename scalar_t, typename out_scalar_t, typename ops_t, typename ident_t=double>
+inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t ident=0) {
   AT_ASSERT(iter.numel() > 0 && iter.ntensors() == 2);
 
-  if (!iter.can_use_32bit_indexing()) {
+  using traits = binary_function_traits<decltype(&ops_t::reduce)>;
+  using arg_t = typename traits::arg1_t;
+  static constexpr bool can_accumulate_in_output =
+    std::is_convertible<arg_t, out_scalar_t>::value;
+
+  bool can_use_32bit_indexing = iter.can_use_32bit_indexing();
+  if (can_accumulate_in_output && !can_use_32bit_indexing) {
     for (auto& sub_iter : iter.with_32bit_indexing()) {
-      gpu_reduce_kernel<scalar_t, out_scalar_t>(sub_iter, pre_op, post_op, op);
+      gpu_reduce_kernel<scalar_t, out_scalar_t>(sub_iter, ops, ident);
     }
     return;
   }
@@ -419,8 +466,6 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const pre_func_t &pre_op,
   char* out_data = (char*)iter.data_ptr(0);
   const char* in_data = (char*)iter.data_ptr(1);
 
-  using traits = binary_function_traits<func_t>;
-  using arg_t = typename traits::arg2_t;
 
   int warp_size = at::cuda::warp_size();
   int warps_per_cta = ReduceConfig::NUM_THREADS / warp_size;
@@ -463,9 +508,6 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const pre_func_t &pre_op,
     config.input_mult[2] = config.split_input(config.ctas_per_output);
   }
 
-  auto output_calc = make_output_calculator(iter);
-  auto input_calc = make_input_calculator(iter);
-
   at::DataPtr buffer;
   at::DataPtr semaphores;
   if (config.should_global_reduce()) {
@@ -476,21 +518,41 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const pre_func_t &pre_op,
     auto stream = at::cuda::getCurrentCUDAStream();
     AT_CUDA_CHECK(cudaMemsetAsync(semaphores.get(), 0, config.semaphore_size(), stream));
   }
-  auto reduce = ReduceOp<scalar_t, func_t, pre_func_t, post_func_t, out_scalar_t>(
-      op,
-      config,
-      input_calc,
-      output_calc,
-      in_data,
-      out_data,
-      buffer.get(),
-      (int*)semaphores.get(),
-      pre_op,
-      post_op);
-  reduce.ident = ident;
-  reduce.accumulate = iter.should_accumulate();
-
-  launch_reduce_kernel<ReduceConfig::NUM_THREADS>(config, reduce);
+
+  if (can_use_32bit_indexing) {
+    auto output_calc = make_output_calculator<uint32_t>(iter);
+    auto input_calc = make_input_calculator<uint32_t>(iter);
+    auto reduce = ReduceOp<scalar_t, ops_t, uint32_t, out_scalar_t>(
+        ops,
+        config,
+        input_calc,
+        output_calc,
+        in_data,
+        out_data,
+        buffer.get(),
+        (int*)semaphores.get(),
+        ident);
+    reduce.accumulate = iter.should_accumulate();
+
+    launch_reduce_kernel<ReduceConfig::NUM_THREADS>(config, reduce);
+  } else {
+    auto output_calc = make_output_calculator<uint64_t>(iter);
+    auto input_calc = make_input_calculator<uint64_t>(iter);
+    auto reduce = ReduceOp<scalar_t, ops_t, uint64_t, out_scalar_t>(
+        ops,
+        config,
+        input_calc,
+        output_calc,
+        in_data,
+        out_data,
+        buffer.get(),
+        (int*)semaphores.get(),
+        ident);
+    AT_ASSERT(!iter.should_accumulate());
+    reduce.accumulate = false;
+
+    launch_reduce_kernel<ReduceConfig::NUM_THREADS>(config, reduce);
+  }
 }
 
 }} // namespace at::native
index 283fc57..d1d9dc1 100644 (file)
@@ -1,12 +1,15 @@
+#include <ATen/native/SharedReduceOps.h>
 #include <ATen/AccumulateType.h>
 #include <ATen/Context.h>
 #include <ATen/Dispatch.h>
+#include <ATen/native/cuda/DeviceSqrt.cuh>
 #include <ATen/native/cuda/Loops.cuh>
 #include <ATen/native/cuda/Reduce.cuh>
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/ReduceOps.h>
 #include <limits>
+#include <tuple>
 
 
 namespace at { namespace native {
@@ -24,10 +27,19 @@ struct SimpleCopy {
 
 template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
 void sum_kernel_impl(TensorIterator& iter) {
-  gpu_reduce_kernel<scalar_t, out_t>(iter, SimpleCopy<acc_t>(), SimpleCopy<acc_t>(),
-                                     []GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
+  gpu_reduce_kernel<scalar_t, out_t>(iter, func_wrapper<out_t> ([]GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
     return a + b;
-  });
+  }));
+}
+
+template <typename scalar_t>
+void std_kernel_impl(TensorIterator& iter, bool unbiased) {
+  gpu_reduce_kernel<scalar_t, scalar_t>(iter, WelfordOps<scalar_t, scalar_t> { unbiased }, WelfordData<scalar_t> {});
+}
+
+template <>
+void std_kernel_impl<at::Half>(TensorIterator& iter, bool unbiased) {
+  gpu_reduce_kernel<at::Half, at::Half>(iter, WelfordOps<at::Half, float> { unbiased }, WelfordData<float> {});
 }
 
 #ifdef __HIPCC__
@@ -37,27 +49,29 @@ void sum_kernel_impl<int16_t, int16_t>(TensorIterator& iter) {
   // compiler segfaults:
   // https://bugs.llvm.org/show_bug.cgi?id=39602
   // To work around it, use int32 as the accumulate type.
-  gpu_reduce_kernel<int16_t, int16_t>(iter, SimpleCopy<int32_t>(), SimpleCopy<int32_t>(),
-                                      []GPU_LAMBDA(int32_t a, int32_t b) -> int32_t {
+  gpu_reduce_kernel<int16_t, int16_t>(iter, func_wrapper<int16_t> ([]GPU_LAMBDA(int32_t a, int32_t b) -> int32_t {
     return a + b;
-  });
+  }));
 }
 #endif
 
 template <typename scalar_t, typename acc_t=scalar_t>
 void prod_kernel_impl(TensorIterator& iter) {
-  gpu_reduce_kernel<scalar_t, scalar_t>(iter, SimpleCopy<acc_t>(), SimpleCopy<acc_t>(),
-                                        []GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
+  gpu_reduce_kernel<scalar_t, scalar_t>(iter, func_wrapper<scalar_t> ([]GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
     return a * b;
-  }, 1);
+  }), 1);
+}
+
+static void std_kernel_cuda(TensorIterator& iter, bool unbiased) {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.type(), "std", [&]() {
+    std_kernel_impl<scalar_t>(iter, unbiased);
+  });
 }
 
 template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
 void mean_kernel_impl(TensorIterator& iter) {
   float factor = float(iter.num_output_elements()) / iter.numel();
-  gpu_reduce_kernel<scalar_t, out_t>(iter, SimpleCopy<acc_t>(), 
-      [factor]GPU_LAMBDA(acc_t a) -> acc_t { return a*factor; },
-      []GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { return a + b; });
+  gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<acc_t, float> {factor});
 }
 
 #ifdef __HIPCC__
@@ -68,9 +82,7 @@ void mean_kernel_impl<int16_t, int16_t, int16_t>(TensorIterator& iter) {
   // https://bugs.llvm.org/show_bug.cgi?id=39602
   // To work around it, use int32 as the accumulate type.
   float factor = float(iter.num_output_elements()) / iter.numel();
-  gpu_reduce_kernel<int16_t, int16_t>(iter, SimpleCopy<int32_t>(),
-      [factor]GPU_LAMBDA(int32_t a) -> int32_t { return a*factor; },
-      []GPU_LAMBDA(int32_t a, int32_t b) -> int32_t { return a + b; });
+  gpu_reduce_kernel<int16_t, int16_t>(iter, MeanOps<int32_t, float> {factor});
 }
 #endif // __HIPCC__
 
@@ -107,6 +119,7 @@ static void mean_kernel_cuda(TensorIterator& iter) {
   });
 }
 
+REGISTER_DISPATCH(std_stub, &std_kernel_cuda);
 REGISTER_DISPATCH(sum_stub, &sum_kernel_cuda);
 REGISTER_DISPATCH(prod_stub, &prod_kernel_cuda);
 REGISTER_DISPATCH(mean_stub, &mean_kernel_cuda);
index 216ea85..d2718b5 100644 (file)
@@ -1943,6 +1943,10 @@ class _TestTorchMixin(object):
                         expected = numpy_op(tensor.numpy(), dim)
                     actual = pytorch_op(tensor, dim)
                     self._assert_matches_numpy(actual, expected)
+                    if torch.cuda.is_available():
+                        self._assert_matches_numpy(pytorch_op(tensor.cuda(),
+                                                              dim).cpu(),
+                                                   expected)
         do_one(self._make_tensors((5, 400000), use_floating=use_floating,
                use_integral=use_integral), 1)
         do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
index fa05483..55fde66 100644 (file)
@@ -2722,7 +2722,7 @@ output tensor having 1 (or ``len(dim)``) fewer dimension(s).
 
 Args:
     input (Tensor): the input tensor
-    dim (int): the dimension to reduce
+    dim (int or tuple of ints): the dimension or dimensions to reduce
     keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not
     out (Tensor): the output tensor
 
@@ -4443,19 +4443,20 @@ Example::
 .. function:: std(input, dim, keepdim=False, unbiased=True, out=None) -> Tensor
 
 Returns the standard-deviation of each row of the :attr:`input` tensor in the
-given dimension :attr:`dim`.
+dimension :attr:`dim`. If :attr:`dim` is a list of dimensions,
+reduce over all of them.
 
-If :attr:`keepdim` is ``True``, the output tensor is of the same size as
-:attr:`input` except in the dimension :attr:`dim` where it is of size 1.
-Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting
-in the output tensor having 1 fewer dimension than :attr:`input`.
+If :attr:`keepdim` is ``True``, the output tensor is of the same size
+as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1.
+Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the
+output tensor having 1 (or ``len(dim)``) fewer dimension(s).
 
 If :attr:`unbiased` is ``False``, then the standard-deviation will be calculated
 via the biased estimator. Otherwise, Bessel's correction will be used.
 
 Args:
     input (Tensor): the input tensor
-    dim (int): the dimension to reduce
+    dim (int or tuple of ints): the dimension or dimensions to reduce
     keepdim (bool): whether the output tensor has :attr:`dim` retained or not
     unbiased (bool): whether to use the unbiased estimation or not
     out (Tensor, optional): the output tensor