Separate Moments from math and optimize it (#16175)
authorXiaomeng Yang <yangxm@fb.com>
Sun, 20 Jan 2019 16:50:32 +0000 (08:50 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sun, 20 Jan 2019 16:53:25 +0000 (08:53 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16175

Separate Moments from math and optimize it

i-am-not-moving-c2-to-c10

Reviewed By: houseroad

Differential Revision: D13742472

fbshipit-source-id: 90757d908d38c98ca69818855aaf68315e525992

17 files changed:
caffe2/operators/group_norm_op.h
caffe2/operators/layer_norm_op.h
caffe2/operators/moments_op.h
caffe2/operators/spatial_batch_norm_op.h
caffe2/quantization/server/group_norm_dnnlowp_op.cc
caffe2/utils/CMakeLists.txt
caffe2/utils/fixed_divisor.h
caffe2/utils/math.h
caffe2/utils/math/reduce.cc [new file with mode: 0644]
caffe2/utils/math/reduce.cu [new file with mode: 0644]
caffe2/utils/math/reduce.h [new file with mode: 0644]
caffe2/utils/math_cpu.cc
caffe2/utils/math_gpu.cu
caffe2/utils/math_gpu_test.cc
caffe2/utils/math_test.cc
caffe2/utils/math_utils.cc
caffe2/utils/math_utils.h

index 16bbe64..0af276f 100644 (file)
@@ -57,8 +57,10 @@ class GroupNormOp final : public Operator<Context> {
       mu_data = mu->template mutable_data<T>();
       rsig_data = rsig->template mutable_data<T>();
     } else {
-      ReinitializeTensor(&mu_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
-      ReinitializeTensor(&rsig_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
+      ReinitializeTensor(
+          &mu_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
+      ReinitializeTensor(
+          &rsig_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
       mu_data = mu_.template mutable_data<T>();
       rsig_data = rsig_.template mutable_data<T>();
     }
@@ -88,24 +90,26 @@ class GroupNormOp final : public Operator<Context> {
       T* mu,
       T* rsig) {
     const int C = G * D;
-    ReinitializeTensor(&scale_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
-    ReinitializeTensor(&bias_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
+    ReinitializeTensor(
+        &scale_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
+    ReinitializeTensor(
+        &bias_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
     T* scale_data = scale_.template mutable_data<T>();
     T* bias_data = bias_.template mutable_data<T>();
     if (order_ == StorageOrder::NCHW) {
-      const std::array<int, 2> dims = {N * G, D * HxW};
-      const int axis = 1;
+      const std::array<int, 2> X_dims = {N * G, D * HxW};
+      const std::array<int, 2> Y_dims = {N * G, 1};
       math::Moments<T, Context>(
-          2, dims.data(), 1, &axis, X, mu, rsig, &context_);
+          2, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_);
       math::InvStd<T, Context>(
           N * G, static_cast<T>(epsilon_), rsig, rsig, &context_);
       ComputeFusedParams(N, G, D, mu, rsig, gamma, beta, scale_data, bias_data);
       GroupNormForwardNCHW(N, C, HxW, X, scale_data, bias_data, Y);
     } else {
-      const std::array<int, 4> dims = {N, HxW, G, D};
-      const std::array<int, 2> axes = {1, 3};
+      const std::array<int, 4> X_dims = {N, HxW, G, D};
+      const std::array<int, 4> Y_dims = {N, 1, G, 1};
       math::Moments<T, Context>(
-          4, dims.data(), 2, axes.data(), X, mu, rsig, &context_);
+          4, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_);
       math::InvStd<T, Context>(
           N * G, static_cast<T>(epsilon_), rsig, rsig, &context_);
       ComputeFusedParams(N, G, D, mu, rsig, gamma, beta, scale_data, bias_data);
index 94305c0..bddb182 100644 (file)
@@ -37,22 +37,22 @@ class LayerNormOp final : public Operator<Context> {
     moments_dims.push_back(1);
     auto* mean = Output(1, moments_dims, at::dtype<T>());
     auto* sig = Output(2, moments_dims, at::dtype<T>());
-    runLayerNorm<T>(X, Y, mean, sig, canonical_axis, epsilon_, &scale_, &bias_, &context_);
+    runLayerNorm<T>(
+        X, Y, mean, sig, canonical_axis, epsilon_, &scale_, &bias_, &context_);
     return true;
   }
 
-  template<typename T>
+  template <typename T>
   static void runLayerNorm(
-    const Tensor& X,
-    Tensor* Y,
-    Tensor* mean,
-    Tensor* sig,
-    int canonical_axis,
-    float epsilon,
-    Tensor* scale_buffer,
-    Tensor* bias_buffer,
-    Context* context
-  ) {
+      const Tensor& X,
+      Tensor* Y,
+      Tensor* mean,
+      Tensor* sig,
+      int canonical_axis,
+      float epsilon,
+      Tensor* scale_buffer,
+      Tensor* bias_buffer,
+      Context* context) {
     CAFFE_ENFORCE_GE(X.dim(), 2, "LayerNorm requires input dim >= 2.");
     const int M = X.size_to_dim(canonical_axis);
     const int N = X.size_from_dim(canonical_axis);
@@ -67,12 +67,19 @@ class LayerNormOp final : public Operator<Context> {
     T* scale_data = scale_buffer->template mutable_data<T>();
     T* bias_data = bias_buffer->template mutable_data<T>();
 
-    const std::array<int, 2> dims = {M, N};
-    const int axis = 1;
+    const std::array<int, 2> X_dims = {M, N};
+    const std::array<int, 2> Y_dims = {M, 1};
     math::Moments<T, Context>(
-        2, dims.data(), 1, &axis, X_data, mean_data, sig_data, context);
+        2, X_dims.data(), Y_dims.data(), X_data, mean_data, sig_data, context);
     ComputeStdDevAndFusedParams<T>(
-        M, mean_data, sig_data, sig_data, scale_data, bias_data, epsilon, context);
+        M,
+        mean_data,
+        sig_data,
+        sig_data,
+        scale_data,
+        bias_data,
+        epsilon,
+        context);
     LayerNormForward<T>(M, N, X_data, scale_data, bias_data, Y_data, context);
   }
 
@@ -132,11 +139,16 @@ class LayerNormGradientOp final : public Operator<Context> {
     const int N = X.size_from_dim(canonical_axis);
 
     auto* dX = Output(0, X.sizes(), at::dtype<T>());
-    ReinitializeTensor(&ds_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
-    ReinitializeTensor(&db_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
-    ReinitializeTensor(&dY_scale_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
-    ReinitializeTensor(&X_scale_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
-    ReinitializeTensor(&bias_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
+    ReinitializeTensor(
+        &ds_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
+    ReinitializeTensor(
+        &db_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
+    ReinitializeTensor(
+        &dY_scale_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
+    ReinitializeTensor(
+        &X_scale_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
+    ReinitializeTensor(
+        &bias_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
     const T* dY_data = dY.template data<T>();
     const T* X_data = X.template data<T>();
     const T* mean_data = mean.template data<T>();
index 608ece7..9ce1730 100644 (file)
@@ -36,29 +36,32 @@ class MomentsOp final : public Operator<Context> {
           "Axes ids must be smaller than the dimensions of input.");
     }
     const std::vector<int> X_dims(X.sizes().cbegin(), X.sizes().cend());
-    std::vector<int64_t> Y_dims;
-    Y_dims.reserve(ndim);
+    std::vector<int> Y_dims = X_dims;
+    for (const int axis : axes_) {
+      Y_dims[axis] = 1;
+    }
+    std::vector<std::int64_t> output_dims;
+    output_dims.reserve(ndim);
     std::size_t cur_axis = 0;
     for (int i = 0; i < ndim; ++i) {
       if (cur_axis < axes_.size() && i == axes_[cur_axis]) {
         if (keep_dims_) {
-          Y_dims.push_back(1);
+          output_dims.push_back(1);
         }
         ++cur_axis;
       } else {
-        Y_dims.push_back(X_dims[i]);
+        output_dims.push_back(X_dims[i]);
       }
     }
-    auto* mean = Output(0, Y_dims, at::dtype<T>());
-    auto* variance = Output(1, Y_dims, at::dtype<T>());
+    auto* mean = Output(0, output_dims, at::dtype<T>());
+    auto* var = Output(1, output_dims, at::dtype<T>());
     math::Moments<float, Context>(
         X_dims.size(),
         X_dims.data(),
-        axes_.size(),
-        axes_.data(),
+        Y_dims.data(),
         X.template data<T>(),
         mean->template mutable_data<T>(),
-        variance->template mutable_data<T>(),
+        var->template mutable_data<T>(),
         &context_);
     return true;
   }
index eb8185a..d56ff1a 100644 (file)
@@ -102,26 +102,31 @@ class SpatialBNOp : public Operator<Context> {
       CAFFE_ENFORCE(
           IsInputOutputAlias(4, 2), "Input 4 and Output 2 should be alias.");
 
-      Tensor* running_mean, *running_var;
+      Tensor* running_mean = nullptr;
+      Tensor* running_var = nullptr;
       const auto& mean = Input(EST_MEAN);
       const auto& var = Input(EST_VAR);
       if (mean.numel() != C) {
-       running_mean = Output(RUNNING_MEAN, {C}, at::dtype<T>());
-       C10_LOG_EVERY_MS(WARNING, 1000) << "[Depreacated] Running mean is not initialized in SpatialBatchNorm Op";
-       math::Set<T, Context>(C, T(0), running_mean->template mutable_data<T>(), &context_);
+        running_mean = Output(RUNNING_MEAN, {C}, at::dtype<T>());
+        C10_LOG_EVERY_MS(WARNING, 1000)
+            << "[Depreacated] Running mean is not initialized in "
+               "SpatialBatchNorm Op";
+        math::Set<T, Context>(
+            C, T(0), running_mean->template mutable_data<T>(), &context_);
       } else {
-       running_mean = Output(RUNNING_MEAN, {C}, at::dtype<T>());
+        running_mean = Output(RUNNING_MEAN, {C}, at::dtype<T>());
       }
       if (var.numel() != C) {
         running_var = Output(RUNNING_VAR, {C}, at::dtype<T>());
         math::Set<T, Context>(
             C, T(0), running_var->template mutable_data<T>(), &context_);
-        C10_LOG_EVERY_MS(WARNING, 1000) << "[Deprecated] Running variance is not initialized in SpatialBatchNorm Op";
+        C10_LOG_EVERY_MS(WARNING, 1000)
+            << "[Deprecated] Running variance is not initialized in "
+               "SpatialBatchNorm Op";
       } else {
         running_var = Output(RUNNING_VAR, {C}, at::dtype<T>());
       }
 
-
       T* running_mean_data = running_mean->template mutable_data<T>();
       T* running_var_data = running_var->template mutable_data<T>();
       if (N == 0) {
@@ -144,25 +149,23 @@ class SpatialBNOp : public Operator<Context> {
             saved_rstd_data);
       } else {
         if (order_ == StorageOrder::NCHW) {
-          const std::array<int, 3> dims = {N, C, HxW};
-          const std::array<int, 3> axes = {0, 2};
+          const std::array<int, 3> X_dims_arr = {N, C, HxW};
+          const std::array<int, 3> Y_dims_arr = {1, C, 1};
           math::Moments<T, Context>(
               3,
-              dims.data(),
-              2,
-              axes.data(),
+              X_dims_arr.data(),
+              Y_dims_arr.data(),
               X_data,
               saved_mean_data,
               saved_rstd_data,
               &context_);
         } else {
-          const std::array<int, 2> dims = {N * HxW, C};
-          const int axis = 0;
+          const std::array<int, 2> X_dims_arr = {N * HxW, C};
+          const std::array<int, 2> Y_dims_arr = {1, C};
           math::Moments<T, Context>(
               2,
-              dims.data(),
-              1,
-              &axis,
+              X_dims_arr.data(),
+              Y_dims_arr.data(),
               X_data,
               saved_mean_data,
               saved_rstd_data,
@@ -328,8 +331,8 @@ class SpatialBNGradientOp : public Operator<Context> {
       const auto& dscale_sum = Input(AGGREGATE_SCALE_GRAD);
       const auto& dbias_sum = Input(AGGREGATE_BIAS_GRAD);
       // Note: previously there was alias check to decide whether to call
-      // ResizeLike or not, since we only call Resize when the size does not match
-      // the size of cached Tensor, this check is not necessary
+      // ResizeLike or not, since we only call Resize when the size does not
+      // match the size of cached Tensor, this check is not necessary
       dscale_sizes = dscale_sum.sizes();
       dbias_sizes = dbias_sum.sizes();
     }
index 8268b05..50d01b9 100644 (file)
@@ -255,10 +255,16 @@ void GroupNormDNNLowPOp<T>::DequantizedGroupMomentsNCHW(
   const int inner_size = K * HxW;
   X_dequantized_.resize(size);
   fbgemm::Dequantize<T>(X, X_dequantized_.data(), size, in_qparams_[INPUT]);
-  const std::array<int, 2> dims = {outer_size, inner_size};
-  const int axis = 1;
+  const std::array<int, 2> X_dims = {outer_size, inner_size};
+  const std::array<int, 2> Y_dims = {outer_size, 1};
   math::Moments<float, CPUContext>(
-      2, dims.data(), 1, &axis, X_dequantized_.data(), mu, rsig, &context_);
+      2,
+      X_dims.data(),
+      Y_dims.data(),
+      X_dequantized_.data(),
+      mu,
+      rsig,
+      &context_);
   math::InvStd<float>(outer_size, epsilon_, rsig, rsig, &context_);
 }
 
@@ -276,13 +282,12 @@ void GroupNormDNNLowPOp<T>::DequantizedGroupMomentsNHWC(
   const int outer_size = N * G;
   X_dequantized_.resize(size);
   fbgemm::Dequantize<T>(X, X_dequantized_.data(), size, in_qparams_[INPUT]);
-  const std::array<int, 4> dims = {N, HxW, G, K};
-  const std::array<int, 2> axes = {1, 3};
+  const std::array<int, 4> X_dims = {N, HxW, G, K};
+  const std::array<int, 4> Y_dims = {N, 1, G, 1};
   math::Moments<float, CPUContext>(
       4,
-      dims.data(),
-      2,
-      axes.data(),
+      X_dims.data(),
+      Y_dims.data(),
       X_dequantized_.data(),
       mu,
       rsig,
index fafe6f1..f2ba60f 100644 (file)
@@ -2,6 +2,7 @@ list(APPEND Caffe2_CPU_SRCS
   utils/bench_utils.cc
   utils/cpuid.cc
   utils/math/elementwise.cc
+  utils/math/reduce.cc
   utils/math_cpu.cc
   utils/math_utils.cc
   utils/murmur_hash3.cc
@@ -26,11 +27,13 @@ endif()
 
 set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS}
         utils/math/elementwise.cu
+        utils/math/reduce.cu
         utils/math_gpu.cu
         )
 
 set(Caffe2_HIP_SRCS ${Caffe2_HIP_SRCS}
         utils/math/hip/elementwise.hip
+        utils/math/hip/reduce.hip
         utils/hip/math_gpu.hip
         )
 
index 5ad189c..3c7aaf0 100644 (file)
@@ -1,8 +1,6 @@
 #ifndef CAFFE2_UTILS_FIXED_DIVISOR_H_
 #define CAFFE2_UTILS_FIXED_DIVISOR_H_
 
-#include <stdint.h>
-
 #include <cstdint>
 #include <cstdio>
 #include <cstdlib>
@@ -29,13 +27,16 @@ class FixedDivisor<std::int32_t> {
   FixedDivisor() = default;
 
   explicit FixedDivisor(const std::int32_t d) : d_(d) {
+#ifndef __HIP_PLATFORM_HCC__
     CalcSignedMagic();
+#endif // __HIP_PLATFORM_HCC__
   }
 
   FIXED_DIVISOR_DECL std::int32_t d() const {
     return d_;
   }
 
+#ifndef __HIP_PLATFORM_HCC__
   FIXED_DIVISOR_DECL std::uint64_t magic() const {
     return magic_;
   }
@@ -43,12 +44,17 @@ class FixedDivisor<std::int32_t> {
   FIXED_DIVISOR_DECL int shift() const {
     return shift_;
   }
+#endif // __HIP_PLATFORM_HCC__
 
   /// Calculates `q = n / d`.
   FIXED_DIVISOR_DECL std::int32_t Div(const std::int32_t n) const {
+#ifdef __HIP_PLATFORM_HCC__
+    return n / d_;
+#else // __HIP_PLATFORM_HCC__
     // In lieu of a mulhi instruction being available, perform the
     // work in uint64
     return (int32_t)((magic_ * (uint64_t)n) >> shift_);
+#endif // __HIP_PLATFORM_HCC__
   }
 
   /// Calculates `r = n % d`.
@@ -64,6 +70,7 @@ class FixedDivisor<std::int32_t> {
   }
 
  private:
+#ifndef __HIP_PLATFORM_HCC__
   // Calculates magic multiplicative value and shift amount for calculating `q =
   // n / d` for signed 32-bit integers.
   // Implementation taken from Hacker's Delight section 10.
@@ -107,10 +114,14 @@ class FixedDivisor<std::int32_t> {
     shift_ = p;
     magic_ = (std::uint64_t)(std::uint32_t)magic;
   }
+#endif // __HIP_PLATFORM_HCC__
 
   std::int32_t d_ = 1;
+
+#ifndef __HIP_PLATFORM_HCC__
   std::uint64_t magic_;
   int shift_;
+#endif // __HIP_PLATFORM_HCC__
 };
 
 } // namespace caffe2
index af0e668..6d6986f 100644 (file)
@@ -16,6 +16,7 @@ extern "C" {
 #include "caffe2/core/common.h"
 #include "caffe2/core/types.h"
 #include "caffe2/utils/math/elementwise.h"
+#include "caffe2/utils/math/reduce.h"
 #include "caffe2/utils/math_utils.h"
 
 namespace caffe2 {
@@ -254,18 +255,6 @@ CAFFE2_API void Broadcast(
     T* Y,
     Context* context);
 
-// Computes mean and variance over axes.
-template <typename T, class Context>
-CAFFE2_API void Moments(
-    const int num_dims,
-    const int* dims,
-    const int num_axes,
-    const int* axes,
-    const T* X,
-    T* mean,
-    T* variance,
-    Context* context);
-
 // Computes inv_std from variance.
 template <typename T, class Context>
 CAFFE2_API void InvStd(
diff --git a/caffe2/utils/math/reduce.cc b/caffe2/utils/math/reduce.cc
new file mode 100644 (file)
index 0000000..c654834
--- /dev/null
@@ -0,0 +1,145 @@
+#include "caffe2/utils/math/reduce.h"
+
+#include <algorithm>
+#include <cstring>
+#include <functional>
+#include <numeric>
+#include <vector>
+
+#include "caffe2/core/context.h"
+#include "caffe2/utils/eigen_utils.h"
+#include "caffe2/utils/math_utils.h"
+
+namespace caffe2 {
+namespace math {
+
+namespace {
+
+template <typename T>
+C10_EXPORT void
+RowwiseMoments(const int rows, const int cols, const T* X, T* mean, T* var) {
+  ConstEigenArrayMap<T> X_arr(X, cols, rows);
+  EigenVectorArrayMap<T> mean_arr(mean, rows);
+  EigenVectorArrayMap<T> var_arr(var, rows);
+  mean_arr = X_arr.colwise().mean();
+  var_arr = X_arr.square().colwise().mean() - mean_arr.square().transpose();
+}
+
+template <typename T>
+C10_EXPORT void
+ColwiseMoments(const int rows, const int cols, const T* X, T* mean, T* var) {
+  std::memset(mean, 0, sizeof(T) * cols);
+  std::memset(var, 0, sizeof(T) * cols);
+  ConstEigenArrayMap<T> X_arr(X, cols, rows);
+  EigenVectorArrayMap<T> mean_arr(mean, cols);
+  EigenVectorArrayMap<T> var_arr(var, cols);
+  // Eigen rowwise reduction is about 10 times slower than this for-loop.
+  for (int i = 0; i < rows; ++i) {
+    mean_arr += X_arr.col(i);
+    var_arr += X_arr.col(i).square();
+  }
+  const T scale = T(1) / static_cast<T>(rows);
+  mean_arr *= scale;
+  var_arr = var_arr * scale - mean_arr.square();
+}
+
+template <typename T>
+C10_EXPORT void BothEndsMoments(
+    const int pre,
+    const int mid,
+    const int nxt,
+    const T* X,
+    T* mean,
+    T* var) {
+  std::memset(mean, 0, sizeof(T) * mid);
+  std::memset(var, 0, sizeof(T) * mid);
+  EigenVectorArrayMap<T> mean_arr(mean, mid);
+  EigenVectorArrayMap<T> var_arr(var, mid);
+  ConstEigenArrayMap<T> X_arr(X, nxt, pre * mid);
+  for (int i = 0; i < pre; ++i) {
+    for (int j = 0; j < mid; ++j) {
+      const int c = i * mid + j;
+      mean_arr(j) += X_arr.col(c).sum();
+      var_arr(j) += X_arr.col(c).square().sum();
+    }
+  }
+  const T scale = T(1) / static_cast<T>(pre * nxt);
+  mean_arr *= scale;
+  var_arr = var_arr * scale - mean_arr.square();
+}
+
+template <typename T>
+C10_EXPORT void MomentsImpl(
+    const int ndim,
+    const int* X_dims,
+    const int* Y_dims,
+    const T* X,
+    T* mean,
+    T* var,
+    CPUContext* /* context */) {
+  const int X_size =
+      std::accumulate(X_dims, X_dims + ndim, 1, std::multiplies<int>());
+  const int Y_size =
+      std::accumulate(Y_dims, Y_dims + ndim, 1, std::multiplies<int>());
+  if (X_size == 0) {
+    std::memset(mean, 0, sizeof(T) * Y_size);
+    std::memset(var, 0, sizeof(T) * Y_size);
+    return;
+  }
+  if (std::equal(X_dims, X_dims + ndim, Y_dims)) {
+    std::memcpy(mean, X, sizeof(T) * Y_size);
+    std::memset(var, 0, sizeof(T) * Y_size);
+    return;
+  }
+  int rows;
+  int cols;
+  if (utils::IsRowwiseReduce(ndim, X_dims, Y_dims, &rows, &cols)) {
+    RowwiseMoments<T>(rows, cols, X, mean, var);
+    return;
+  }
+  if (utils::IsColwiseReduce(ndim, X_dims, Y_dims, &rows, &cols)) {
+    ColwiseMoments<T>(rows, cols, X, mean, var);
+    return;
+  }
+  int pre;
+  int mid;
+  int nxt;
+  if (utils::IsBothEndsReduce(ndim, X_dims, Y_dims, &pre, &mid, &nxt)) {
+    BothEndsMoments<T>(pre, mid, nxt, X, mean, var);
+    return;
+  }
+  std::memset(mean, 0, sizeof(T) * Y_size);
+  std::memset(var, 0, sizeof(T) * Y_size);
+  std::vector<int> index(ndim, 0);
+  for (int X_index = 0; X_index < X_size; ++X_index) {
+    const int Y_index = utils::GetIndexFromDims(ndim, Y_dims, index.data());
+    mean[Y_index] += X[X_index];
+    var[Y_index] += X[X_index] * X[X_index];
+    utils::IncreaseIndexInDims(ndim, X_dims, index.data());
+  }
+  const T scale = static_cast<T>(Y_size) / static_cast<T>(X_size);
+  EigenVectorArrayMap<T> mean_arr(mean, Y_size);
+  EigenVectorArrayMap<T> var_arr(var, Y_size);
+  mean_arr *= scale;
+  var_arr = var_arr * scale - mean_arr.square();
+}
+
+} // namespace
+
+#define CAFFE2_SPECIALIZED_MOMENTS(T)                            \
+  template <>                                                    \
+  C10_EXPORT void Moments<T, CPUContext>(                        \
+      const int ndim,                                            \
+      const int* X_dims,                                         \
+      const int* Y_dims,                                         \
+      const T* X,                                                \
+      T* mean,                                                   \
+      T* var,                                                    \
+      CPUContext* context) {                                     \
+    MomentsImpl<T>(ndim, X_dims, Y_dims, X, mean, var, context); \
+  }
+CAFFE2_SPECIALIZED_MOMENTS(float)
+#undef CAFFE2_SPECIALIZED_MOMENTS
+
+} // namespace math
+} // namespace caffe2
diff --git a/caffe2/utils/math/reduce.cu b/caffe2/utils/math/reduce.cu
new file mode 100644 (file)
index 0000000..f597ec7
--- /dev/null
@@ -0,0 +1,287 @@
+#include "caffe2/utils/math.h"
+
+#include <algorithm>
+#include <functional>
+#include <numeric>
+#include <vector>
+
+#include <cub/block/block_reduce.cuh>
+#include <cub/cub.cuh>
+
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/utils/fixed_divisor.h"
+#include "caffe2/utils/math_utils.h"
+
+namespace caffe2 {
+namespace math {
+
+namespace {
+
+template <typename T>
+using BlockReduce = cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS>;
+
+template <typename T, int kBlockDimX, int kBlockDimY>
+using BlockReduce2D = cub::
+    BlockReduce<T, kBlockDimX, cub::BLOCK_REDUCE_WARP_REDUCTIONS, kBlockDimY>;
+
+template <typename T>
+__global__ void
+RowwiseMomentsCUDAKernel(const int cols, const T* X, T* mean, T* var) {
+  __shared__ typename BlockReduce<T>::TempStorage m_storage;
+  __shared__ typename BlockReduce<T>::TempStorage v_storage;
+  const T scale = T(1) / static_cast<T>(cols);
+  const int r = blockIdx.x;
+  T m_val = 0;
+  T v_val = 0;
+  for (int c = threadIdx.x; c < cols; c += blockDim.x) {
+    const int X_index = r * cols + c;
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
+    m_val += __ldg(X + X_index);
+    v_val += __ldg(X + X_index) * __ldg(X + X_index);
+#else
+    m_val += X[X_index];
+    v_val += X[X_index] * X[X_index];
+#endif
+  }
+  m_val = BlockReduce<T>(m_storage).Sum(m_val);
+  v_val = BlockReduce<T>(v_storage).Sum(v_val);
+  if (threadIdx.x == 0) {
+    const T mu = m_val * scale;
+    mean[r] = mu;
+    var[r] = v_val * scale - mu * mu;
+  }
+}
+
+template <typename T>
+__global__ void ColwiseMomentsCUDAKernel(
+    const int rows,
+    const int cols,
+    const T* X,
+    T* mean,
+    T* var) {
+  __shared__ typename BlockReduce<T>::TempStorage m_storage;
+  __shared__ typename BlockReduce<T>::TempStorage v_storage;
+  const T scale = T(1) / static_cast<T>(rows);
+  const int c = blockIdx.x;
+  T m_val = 0;
+  T v_val = 0;
+  for (int r = threadIdx.x; r < rows; r += blockDim.x) {
+    const int X_index = r * cols + c;
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
+    m_val += __ldg(X + X_index);
+    v_val += __ldg(X + X_index) * __ldg(X + X_index);
+#else
+    m_val += X[X_index];
+    v_val += X[X_index] * X[X_index];
+#endif
+  }
+  m_val = BlockReduce<T>(m_storage).Sum(m_val);
+  v_val = BlockReduce<T>(v_storage).Sum(v_val);
+  if (threadIdx.x == 0) {
+    const T mu = m_val * scale;
+    mean[c] = mu;
+    var[c] = v_val * scale - mu * mu;
+  }
+}
+
+template <typename T, int kBlockDimX, int kBlockDimY>
+__global__ void BothEndsMomentsCUDAKernel(
+    const int M,
+    const int N,
+    const int K,
+    const T* X,
+    T* mean,
+    T* var) {
+  __shared__
+      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage m_storage;
+  __shared__
+      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage v_storage;
+  const T scale = T(1) / static_cast<T>(M * K);
+  const int n = blockIdx.x;
+  T m_val = 0;
+  T v_val = 0;
+  for (int m = threadIdx.x; m < M; m += blockDim.x) {
+    for (int k = threadIdx.y; k < K; k += blockDim.y) {
+      const int X_index = (m * N + n) * K + k;
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
+      m_val += __ldg(X + X_index);
+      v_val += __ldg(X + X_index) * __ldg(X + X_index);
+#else
+      m_val += X[X_index];
+      v_val += X[X_index] * X[X_index];
+#endif
+    }
+  }
+  m_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(m_storage).Sum(m_val);
+  v_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(v_storage).Sum(v_val);
+  if (threadIdx.x == 0 && threadIdx.y == 0) {
+    const T mu = m_val * scale;
+    mean[n] = mu;
+    var[n] = v_val * scale - mu * mu;
+  }
+}
+
+template <typename T, int D>
+__global__ void MomentsCUDAKernel(
+    const int inner_size,
+    const SimpleArray<int, D> X_strides,
+    const SimpleArray<FixedDivisor<int>, D> Y_dims,
+    const T* X,
+    T* mean,
+    T* var) {
+  __shared__ typename BlockReduce<T>::TempStorage m_storage;
+  __shared__ typename BlockReduce<T>::TempStorage v_storage;
+  const T scale = T(1) / static_cast<T>(inner_size);
+  const int x = blockIdx.x;
+  T m_val = 0;
+  T v_val = 0;
+  for (int y = threadIdx.x; y < inner_size; y += blockDim.x) {
+    int X_index = 0;
+    int Y_index = x * inner_size + y;
+#pragma unroll
+    for (int d = D - 1; d >= 0; --d) {
+      int r;
+      Y_dims.data[d].DivMod(Y_index, &Y_index, &r);
+      X_index += r * X_strides.data[d];
+    }
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
+    m_val += __ldg(X + X_index);
+    v_val += __ldg(X + X_index) * __ldg(X + X_index);
+#else
+    m_val += X[X_index];
+    v_val += X[X_index] * X[X_index];
+#endif
+  }
+  m_val = BlockReduce<T>(m_storage).Sum(m_val);
+  v_val = BlockReduce<T>(v_storage).Sum(v_val);
+  if (threadIdx.x == 0) {
+    const T mu = m_val * scale;
+    mean[x] = mu;
+    var[x] = v_val * scale - mu * mu;
+  }
+}
+
+template <typename T, int D>
+CAFFE2_CUDA_EXPORT void MomentsCUDAImpl(
+    const int outer_size,
+    const int inner_size,
+    const int* dims,
+    const int* axes,
+    const T* X,
+    T* mean,
+    T* var,
+    CUDAContext* context) {
+  SimpleArray<int, D> X_strides;
+  SimpleArray<FixedDivisor<int>, D> Y_dims;
+  utils::ComputeTransposedStrides(D, dims, axes, X_strides.data);
+  for (int i = 0; i < D; ++i) {
+    Y_dims.data[i] = FixedDivisor<int>(dims[axes[i]]);
+  }
+  MomentsCUDAKernel<T, D>
+      <<<outer_size, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+          inner_size, X_strides, Y_dims, X, mean, var);
+}
+
+template <typename T>
+CAFFE2_CUDA_EXPORT void MomentsCUDA(
+    const int ndim,
+    const int* X_dims,
+    const int* Y_dims,
+    const T* X,
+    T* mean,
+    T* var,
+    CUDAContext* context) {
+  CAFFE_ENFORCE(utils::CheckReduceDims(ndim, X_dims, Y_dims));
+  const int X_size =
+      std::accumulate(X_dims, X_dims + ndim, 1, std::multiplies<int>());
+  const int Y_size =
+      std::accumulate(Y_dims, Y_dims + ndim, 1, std::multiplies<int>());
+  if (X_size == 0) {
+    Set<T, CUDAContext>(Y_size, T(0), mean, context);
+    Set<T, CUDAContext>(Y_size, T(0), var, context);
+    return;
+  }
+  if (std::equal(X_dims, X_dims + ndim, Y_dims)) {
+    cudaMemcpyAsync(
+        mean,
+        X,
+        sizeof(T) * X_size,
+        cudaMemcpyDeviceToDevice,
+        context->cuda_stream());
+    Set<T, CUDAContext>(Y_size, T(0), var, context);
+    return;
+  }
+  int rows;
+  int cols;
+  if (utils::IsRowwiseReduce(ndim, X_dims, Y_dims, &rows, &cols)) {
+    RowwiseMomentsCUDAKernel<T>
+        <<<rows, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+            cols, X, mean, var);
+    return;
+  }
+  if (utils::IsColwiseReduce(ndim, X_dims, Y_dims, &rows, &cols)) {
+    ColwiseMomentsCUDAKernel<T>
+        <<<cols, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+            rows, cols, X, mean, var);
+    return;
+  }
+  int M;
+  int N;
+  int K;
+  if (utils::IsBothEndsReduce(ndim, X_dims, Y_dims, &M, &N, &K)) {
+    if (K >= 128) {
+      BothEndsMomentsCUDAKernel<T, 1, 128>
+          <<<N, dim3(1, 128), 0, context->cuda_stream()>>>(
+              M, N, K, X, mean, var);
+    } else if (K >= 64) {
+      BothEndsMomentsCUDAKernel<T, 2, 64>
+          <<<N, dim3(2, 64), 0, context->cuda_stream()>>>(
+              M, N, K, X, mean, var);
+    } else if (K >= 32) {
+      BothEndsMomentsCUDAKernel<T, 4, 32>
+          <<<N, dim3(4, 32), 0, context->cuda_stream()>>>(
+              M, N, K, X, mean, var);
+    } else {
+      BothEndsMomentsCUDAKernel<T, 8, 16>
+          <<<N, dim3(8, 16), 0, context->cuda_stream()>>>(
+              M, N, K, X, mean, var);
+    }
+    return;
+  }
+  std::vector<int> axes(ndim);
+  utils::ComputeTransposeAxesForReduceOp(ndim, Y_dims, axes.data());
+  const int outer_size = Y_size;
+  const int inner_size = X_size / Y_size;
+  DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_1(
+      ndim,
+      MomentsCUDAImpl,
+      T,
+      outer_size,
+      inner_size,
+      X_dims,
+      axes.data(),
+      X,
+      mean,
+      var,
+      context);
+}
+
+} // namespace
+
+#define CAFFE2_SPECIALIZED_CUDA_MOMENTS(T)                       \
+  template <>                                                    \
+  CAFFE2_CUDA_EXPORT void Moments<T, CUDAContext>(               \
+      const int ndim,                                            \
+      const int* X_dims,                                         \
+      const int* Y_dims,                                         \
+      const T* X,                                                \
+      T* mean,                                                   \
+      T* var,                                                    \
+      CUDAContext* context) {                                    \
+    MomentsCUDA<T>(ndim, X_dims, Y_dims, X, mean, var, context); \
+  }
+CAFFE2_SPECIALIZED_CUDA_MOMENTS(float)
+#undef CAFFE2_SPECIALIZED_CUDA_MOMENTS
+
+} // namespace math
+} // namespace caffe2
diff --git a/caffe2/utils/math/reduce.h b/caffe2/utils/math/reduce.h
new file mode 100644 (file)
index 0000000..fce3c95
--- /dev/null
@@ -0,0 +1,24 @@
+#ifndef CAFFE2_UTILS_MATH_REDUCE_H_
+#define CAFFE2_UTILS_MATH_REDUCE_H_
+
+#include "caffe2/core/common.h"
+#include "caffe2/core/types.h"
+
+namespace caffe2 {
+namespace math {
+
+// Computes mean and variance over axes.
+template <typename T, class Context>
+CAFFE2_API void Moments(
+    const int ndims,
+    const int* X_dims,
+    const int* Y_dims,
+    const T* X,
+    T* mean,
+    T* var,
+    Context* context);
+
+} // namespace math
+} // namespace caffe2
+
+#endif // CAFFE2_UTILS_MATH_REDUCE_H_
index d59e004..6b9a506 100644 (file)
@@ -1536,152 +1536,6 @@ CAFFE2_SPECIALIZED_BROADCAST(float)
 CAFFE2_SPECIALIZED_BROADCAST(double)
 #undef CAFFE2_SPECIALIZED_BROADCAST
 
-namespace {
-
-template <typename T>
-C10_EXPORT void RowwiseMoments(
-    const int rows,
-    const int cols,
-    const T* X,
-    T* mean,
-    T* variance) {
-  ConstEigenArrayMap<T> X_arr(X, cols, rows);
-  EigenVectorArrayMap<T> mean_arr(mean, rows);
-  EigenVectorArrayMap<T> var_arr(variance, rows);
-  mean_arr = X_arr.colwise().mean();
-  var_arr = X_arr.square().colwise().mean() - mean_arr.square().transpose();
-}
-
-template <typename T>
-C10_EXPORT void ColwiseMoments(
-    const int rows,
-    const int cols,
-    const T* X,
-    T* mean,
-    T* variance) {
-  std::memset(mean, 0, sizeof(T) * cols);
-  std::memset(variance, 0, sizeof(T) * cols);
-  ConstEigenArrayMap<T> X_arr(X, cols, rows);
-  EigenVectorArrayMap<T> mean_arr(mean, cols);
-  EigenVectorArrayMap<T> var_arr(variance, cols);
-  // Eigen rowwise reduction is about 10 times slower than this for-loop.
-  for (int i = 0; i < rows; ++i) {
-    mean_arr += X_arr.col(i);
-    var_arr += X_arr.col(i).square();
-  }
-  const T scale = T(1) / static_cast<T>(rows);
-  mean_arr *= scale;
-  var_arr = var_arr * scale - mean_arr.square();
-}
-
-template <typename T>
-C10_EXPORT void BothEndsMoments(
-    const int pre,
-    const int mid,
-    const int nxt,
-    const T* X,
-    T* mean,
-    T* variance) {
-  std::memset(mean, 0, sizeof(T) * mid);
-  std::memset(variance, 0, sizeof(T) * mid);
-  EigenVectorArrayMap<T> mean_arr(mean, mid);
-  EigenVectorArrayMap<T> var_arr(variance, mid);
-  ConstEigenArrayMap<T> X_arr(X, nxt, pre * mid);
-  for (int i = 0; i < pre; ++i) {
-    for (int j = 0; j < mid; ++j) {
-      const int c = i * mid + j;
-      mean_arr(j) += X_arr.col(c).sum();
-      var_arr(j) += X_arr.col(c).square().sum();
-    }
-  }
-  const T scale = T(1) / static_cast<T>(pre * nxt);
-  mean_arr *= scale;
-  var_arr = var_arr * scale - mean_arr.square();
-}
-
-template <typename T>
-C10_EXPORT void MomentsImpl(
-    const int num_dims,
-    const int* dims,
-    const int num_axes,
-    const int* axes,
-    const T* X,
-    T* mean,
-    T* variance,
-    CPUContext* context) {
-  std::vector<int> Y_dims_vector(dims, dims + num_dims);
-  for (int i = 0; i < num_axes; ++i) {
-    Y_dims_vector[axes[i]] = 1;
-  }
-  const int* X_dims = dims;
-  const int* Y_dims = Y_dims_vector.data();
-  const int X_size =
-      std::accumulate(X_dims, X_dims + num_dims, 1, std::multiplies<int>());
-  const int Y_size =
-      std::accumulate(Y_dims, Y_dims + num_dims, 1, std::multiplies<int>());
-  if (X_size == 0) {
-    std::memset(mean, 0, sizeof(T) * Y_size);
-    std::memset(variance, 0, sizeof(T) * Y_size);
-    return;
-  }
-  if (std::equal(X_dims, X_dims + num_dims, Y_dims)) {
-    std::memcpy(mean, X, sizeof(T) * Y_size);
-    std::memset(variance, 0, sizeof(T) * Y_size);
-    return;
-  }
-  int rows;
-  int cols;
-  if (utils::IsRowwiseReduce(num_dims, X_dims, Y_dims, &rows, &cols)) {
-    RowwiseMoments<T>(rows, cols, X, mean, variance);
-    return;
-  }
-  if (utils::IsColwiseReduce(num_dims, X_dims, Y_dims, &rows, &cols)) {
-    ColwiseMoments<T>(rows, cols, X, mean, variance);
-    return;
-  }
-  int pre;
-  int mid;
-  int nxt;
-  if (utils::IsBothEndsReduce(num_dims, X_dims, Y_dims, &pre, &mid, &nxt)) {
-    BothEndsMoments<T>(pre, mid, nxt, X, mean, variance);
-    return;
-  }
-  Set<T, CPUContext>(Y_size, T(0), mean, context);
-  Set<T, CPUContext>(Y_size, T(0), variance, context);
-  std::vector<int> index(num_dims, 0);
-  for (int X_index = 0; X_index < X_size; ++X_index) {
-    const int Y_index = utils::GetIndexFromDims(num_dims, Y_dims, index.data());
-    mean[Y_index] += X[X_index];
-    variance[Y_index] += X[X_index] * X[X_index];
-    utils::IncreaseIndexInDims(num_dims, dims, index.data());
-  }
-  const T scale = static_cast<T>(Y_size) / static_cast<T>(X_size);
-  EigenVectorArrayMap<T> mean_arr(mean, Y_size);
-  EigenVectorArrayMap<T> var_arr(variance, Y_size);
-  mean_arr *= scale;
-  var_arr =
-      var_arr * scale - ConstEigenVectorArrayMap<T>(mean, Y_size).square();
-}
-
-} // namespace
-
-#define CAFFE2_SPECIALIZED_MOMENTS(T)                                \
-  template <>                                                        \
-  C10_EXPORT void Moments<T, CPUContext>(                            \
-      const int num_dims,                                            \
-      const int* dims,                                               \
-      const int num_axes,                                            \
-      const int* axes,                                               \
-      const T* X,                                                    \
-      T* mean,                                                       \
-      T* variance,                                                   \
-      CPUContext* context) {                                         \
-    MomentsImpl<T>(                                                  \
-        num_dims, dims, num_axes, axes, X, mean, variance, context); \
-  }
-CAFFE2_SPECIALIZED_MOMENTS(float)
-#undef CAFFE2_SPECIALIZED_MOMENTS
-
 #define CAFFE2_SPECIALIZED_INV_STD(T)                            \
   template <>                                                    \
   void InvStd<T, CPUContext>(                                    \
index 0229969..131b409 100644 (file)
@@ -3569,7 +3569,7 @@ CAFFE2_CUDA_EXPORT void ReduceTensorCUDA(
   }
   if (utils::IsColwiseReduce(num_dims, X_dims, Y_dims, &rows, &cols)) {
     ColwiseReduceKernel<T>
-        <<<std::min(rows, CAFFE_MAXIMUM_NUM_BLOCKS),
+        <<<std::min(cols, CAFFE_MAXIMUM_NUM_BLOCKS),
            CAFFE_CUDA_NUM_THREADS,
            0,
            context->cuda_stream()>>>(rows, cols, reducer, init, alpha, X, Y);
@@ -3813,242 +3813,6 @@ CAFFE2_SPECIALIZED_CUDA_BROADCAST(double)
 namespace {
 
 template <typename T>
-__global__ void RowwiseMomentsCUDAKernel(
-    const int rows,
-    const int cols,
-    const T* X,
-    T* mean,
-    T* variance) {
-  __shared__ typename BlockReduce<T>::TempStorage m_storage;
-  __shared__ typename BlockReduce<T>::TempStorage v_storage;
-  const T scale = T(1) / static_cast<T>(cols);
-  for (int i = blockIdx.x; i < rows; i += gridDim.x) {
-    T m_val = 0;
-    T v_val = 0;
-    for (int j = threadIdx.x; j < cols; j += blockDim.x) {
-      const int X_index = i * cols + j;
-#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
-      m_val += __ldg(X + X_index);
-      v_val += __ldg(X + X_index) * __ldg(X + X_index);
-#else
-      m_val += X[X_index];
-      v_val += X[X_index] * X[X_index];
-#endif
-    }
-    m_val = BlockReduce<T>(m_storage).Sum(m_val);
-    v_val = BlockReduce<T>(v_storage).Sum(v_val);
-    if (threadIdx.x == 0) {
-      const T mu = m_val * scale;
-      mean[i] = mu;
-      variance[i] = v_val * scale - mu * mu;
-    }
-    __syncthreads();
-  }
-}
-
-template <typename T>
-__global__ void ColwiseMomentsCUDAKernel(
-    const int rows,
-    const int cols,
-    const T* X,
-    T* mean,
-    T* variance) {
-  __shared__ typename BlockReduce<T>::TempStorage m_storage;
-  __shared__ typename BlockReduce<T>::TempStorage v_storage;
-  const T scale = T(1) / static_cast<T>(rows);
-  for (int i = blockIdx.x; i < cols; i += gridDim.x) {
-    T m_val = 0;
-    T v_val = 0;
-    for (int j = threadIdx.x; j < rows; j += blockDim.x) {
-      const int X_index = j * cols + i;
-#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
-      m_val += __ldg(X + X_index);
-      v_val += __ldg(X + X_index) * __ldg(X + X_index);
-#else
-      m_val += X[X_index];
-      v_val += X[X_index] * X[X_index];
-#endif
-    }
-    m_val = BlockReduce<T>(m_storage).Sum(m_val);
-    v_val = BlockReduce<T>(v_storage).Sum(v_val);
-    if (threadIdx.x == 0) {
-      const T mu = m_val * scale;
-      mean[i] = mu;
-      variance[i] = v_val * scale - mu * mu;
-    }
-    __syncthreads();
-  }
-}
-
-template <typename T, int D>
-__global__ void MomentsCUDAKernel(
-    const int outer_size,
-    const int inner_size,
-    SimpleArray<int, D> X_strides,
-    SimpleArray<FIXED_DIVISOR, D> Y_dims,
-    const T* X,
-    T* mean,
-    T* variance) {
-  __shared__ typename BlockReduce<T>::TempStorage m_storage;
-  __shared__ typename BlockReduce<T>::TempStorage v_storage;
-  const T scale = T(1) / static_cast<T>(inner_size);
-  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
-    T m_val = 0;
-    T v_val = 0;
-    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
-      int X_index = 0;
-      int Y_index = i * inner_size + j;
-#pragma unroll
-      for (int d = D - 1; d >= 0; --d) {
-        int r;
-        FIXED_DIVISOR_DIV_MOD(Y_dims.data[d], Y_index, &Y_index, &r);
-        X_index += r * X_strides.data[d];
-      }
-#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
-      m_val += __ldg(X + X_index);
-      v_val += __ldg(X + X_index) * __ldg(X + X_index);
-#else
-      m_val += X[X_index];
-      v_val += X[X_index] * X[X_index];
-#endif
-    }
-    m_val = BlockReduce<T>(m_storage).Sum(m_val);
-    v_val = BlockReduce<T>(v_storage).Sum(v_val);
-    if (threadIdx.x == 0) {
-      const T mu = m_val * scale;
-      mean[i] = mu;
-      variance[i] = v_val * scale - mu * mu;
-    }
-    __syncthreads();
-  }
-}
-
-template <typename T, int D>
-CAFFE2_CUDA_EXPORT void MomentsCUDAImpl(
-    const int outer_size,
-    const int inner_size,
-    const int* dims,
-    const int* axes,
-    const T* X,
-    T* mean,
-    T* variance,
-    CUDAContext* context) {
-  SimpleArray<int, D> X_strides;
-  SimpleArray<FIXED_DIVISOR, D> Y_dims;
-  utils::ComputeTransposedStrides(D, dims, axes, X_strides.data);
-  for (int i = 0; i < D; ++i) {
-    Y_dims.data[i] = FIXED_DIVISOR(dims[axes[i]]);
-  }
-  MomentsCUDAKernel<T, D>
-      <<<std::min(outer_size, CAFFE_MAXIMUM_NUM_BLOCKS),
-         CAFFE_CUDA_NUM_THREADS,
-         0,
-         context->cuda_stream()>>>(
-          outer_size, inner_size, X_strides, Y_dims, X, mean, variance);
-}
-
-template <typename T>
-CAFFE2_CUDA_EXPORT void MomentsCUDA(
-    const int num_dims,
-    const int* dims,
-    const int num_axes,
-    const int* axes,
-    const T* X,
-    T* mean,
-    T* variance,
-    CUDAContext* context) {
-  CAFFE_ENFORCE_LE(num_axes, num_dims);
-  std::vector<int> Y_dims_vector(dims, dims + num_dims);
-  for (int i = 0; i < num_axes; ++i) {
-    Y_dims_vector[axes[i]] = 1;
-  }
-  const int* X_dims = dims;
-  const int* Y_dims = Y_dims_vector.data();
-  const int X_size =
-      std::accumulate(X_dims, X_dims + num_dims, 1, std::multiplies<int>());
-  const int Y_size =
-      std::accumulate(Y_dims, Y_dims + num_dims, 1, std::multiplies<int>());
-  if (X_size == 0) {
-    Set<T, CUDAContext>(Y_size, T(0), mean, context);
-    Set<T, CUDAContext>(Y_size, T(0), variance, context);
-    return;
-  }
-  if (std::equal(X_dims, X_dims + num_dims, Y_dims)) {
-    cudaMemcpyAsync(
-        mean,
-        X,
-        sizeof(T) * X_size,
-        cudaMemcpyDeviceToDevice,
-        context->cuda_stream());
-    Set<T, CUDAContext>(Y_size, T(0), variance, context);
-    return;
-  }
-  int rows;
-  int cols;
-  if (utils::IsRowwiseReduce(num_dims, X_dims, Y_dims, &rows, &cols)) {
-    RowwiseMomentsCUDAKernel<T>
-        <<<std::min(rows, CAFFE_MAXIMUM_NUM_BLOCKS),
-           CAFFE_CUDA_NUM_THREADS,
-           0,
-           context->cuda_stream()>>>(rows, cols, X, mean, variance);
-    return;
-  }
-  if (utils::IsColwiseReduce(num_dims, X_dims, Y_dims, &rows, &cols)) {
-    ColwiseMomentsCUDAKernel<T>
-        <<<std::min(rows, CAFFE_MAXIMUM_NUM_BLOCKS),
-           CAFFE_CUDA_NUM_THREADS,
-           0,
-           context->cuda_stream()>>>(rows, cols, X, mean, variance);
-    return;
-  }
-  std::vector<int> transpose_axes(num_dims);
-  utils::ComputeTransposeAxesForReduceOp(
-      num_dims, num_axes, axes, transpose_axes.data());
-  const int pivot = num_dims - num_axes;
-  int outer_size = 1;
-  for (int i = 0; i < pivot; ++i) {
-    outer_size *= dims[transpose_axes[i]];
-  }
-  int inner_size = 1;
-  for (int i = pivot; i < num_dims; ++i) {
-    inner_size *= dims[transpose_axes[i]];
-  }
-  DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_1(
-      num_dims,
-      MomentsCUDAImpl,
-      T,
-      outer_size,
-      inner_size,
-      dims,
-      transpose_axes.data(),
-      X,
-      mean,
-      variance,
-      context);
-}
-
-} // namespace
-
-#define CAFFE2_SPECIALIZED_CUDA_MOMENTS(T)                           \
-  template <>                                                        \
-  CAFFE2_CUDA_EXPORT void Moments<T, CUDAContext>(                   \
-      const int num_dims,                                            \
-      const int* dims,                                               \
-      const int num_axes,                                            \
-      const int* axes,                                               \
-      const T* X,                                                    \
-      T* mean,                                                       \
-      T* variance,                                                   \
-      CUDAContext* context) {                                        \
-    MomentsCUDA<T>(                                                  \
-        num_dims, dims, num_axes, axes, X, mean, variance, context); \
-  }
-CAFFE2_SPECIALIZED_CUDA_MOMENTS(float)
-#undef CAFFE2_SPECIALIZED_CUDA_MOMENTS
-
-namespace {
-
-template <typename T>
 __global__ void
 InvStdCUDAKernel(const int N, const T epsilon, const T* var, T* inv_std);
 
index de1cbbd..80d7ff0 100644 (file)
@@ -726,133 +726,6 @@ TEST_F(BroadcastGPUTest, BroadcastGPUFloatTest) {
       {1.0f, 1.0f, 2.0f, 2.0f, 1.0f, 1.0f, 2.0f, 2.0f});
 }
 
-class MomentsGPUTest : public testing::Test {
- protected:
-  void SetUp() override {
-    if (!HasCudaGPU()) {
-      return;
-    }
-    option_.set_device_type(PROTO_CUDA);
-    cuda_context_ = make_unique<CUDAContext>(option_);
-    Blob* blob_x = ws_.CreateBlob("X");
-    Blob* blob_mean = ws_.CreateBlob("mean");
-    Blob* blob_variance = ws_.CreateBlob("variance");
-    X_ = BlobGetMutableTensor(blob_x, CUDA);
-    mean_ = BlobGetMutableTensor(blob_mean, CUDA);
-    variance_ = BlobGetMutableTensor(blob_variance, CUDA);
-  }
-
-  void SetUpData(
-      const std::vector<int>& X_dims,
-      const std::vector<int>& axes,
-      const std::vector<float>& X_data) {
-    std::vector<int> Y_dims = X_dims;
-    for (const int axis : axes) {
-      Y_dims[axis] = 1;
-    }
-    X_->Resize(X_dims);
-    mean_->Resize(Y_dims);
-    variance_->Resize(Y_dims);
-    ASSERT_EQ(X_data.size(), X_->numel());
-    cuda_context_->CopyFromCPU<float>(
-        X_data.size(), X_data.data(), X_->mutable_data<float>());
-  }
-
-  void VerifyResult(
-      const std::vector<float>& mean_data,
-      const std::vector<float>& variance_data) {
-    Blob* blob_mean_host = ws_.CreateBlob("mean_host");
-    auto* mean_host = BlobGetMutableTensor(blob_mean_host, CPU);
-    mean_host->CopyFrom(*mean_);
-    Blob* blob_variance_host = ws_.CreateBlob("variance_host");
-    auto* variance_host = BlobGetMutableTensor(blob_variance_host, CPU);
-    variance_host->CopyFrom(*variance_);
-
-    ASSERT_EQ(mean_data.size(), mean_host->numel());
-    for (std::size_t i = 0; i < mean_data.size(); ++i) {
-      EXPECT_FLOAT_EQ(mean_data[i], mean_host->data<float>()[i]);
-    }
-    ASSERT_EQ(variance_data.size(), variance_host->numel());
-    for (std::size_t i = 0; i < variance_data.size(); ++i) {
-      EXPECT_NEAR(variance_data[i], variance_host->data<float>()[i], kEps);
-    }
-  }
-
-  void RunMomentsTest(
-      const std::vector<int>& X_dims,
-      const std::vector<int>& axes,
-      const std::vector<float>& X_data,
-      const std::vector<float>& mean_data,
-      const std::vector<float>& variance_data) {
-    SetUpData(X_dims, axes, X_data);
-    math::Moments<float, CUDAContext>(
-        X_dims.size(),
-        X_dims.data(),
-        axes.size(),
-        axes.data(),
-        X_->data<float>(),
-        mean_->mutable_data<float>(),
-        variance_->mutable_data<float>(),
-        cuda_context_.get());
-    VerifyResult(mean_data, variance_data);
-  }
-
-  Workspace ws_;
-  DeviceOption option_;
-  std::unique_ptr<CUDAContext> cuda_context_;
-  Tensor* X_ = nullptr;
-  Tensor* mean_ = nullptr;
-  Tensor* variance_ = nullptr;
-};
-
-TEST_F(MomentsGPUTest, MomentsGPUFloatTest) {
-  if (!HasCudaGPU()) {
-    return;
-  }
-  // Test for 1D tensor.
-  RunMomentsTest({3}, {0}, {1.0f, 2.0f, 3.0f}, {2.0f}, {2.0f / 3.0f});
-
-  // Test for 2D Tensor.
-  RunMomentsTest(
-      {2, 3},
-      {1},
-      {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
-      {2.0f, 5.0f},
-      {2.0f / 3.0f, 2.0f / 3.0f});
-  RunMomentsTest(
-      {2, 3},
-      {0},
-      {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
-      {2.5f, 3.5f, 4.5f},
-      {2.25f, 2.25f, 2.25f});
-  RunMomentsTest(
-      {2, 3},
-      {0, 1},
-      {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
-      {3.5f},
-      {35.0f / 12.0f});
-
-  // Test for 3D tensor.
-  RunMomentsTest(
-      {2, 2, 2},
-      {1, 2},
-      {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f},
-      {2.5f, 6.5f},
-      {1.25, 1.25});
-  RunMomentsTest(
-      {2, 2, 2},
-      {0, 1},
-      {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f},
-      {4.0f, 5.0f},
-      {5.0f, 5.0f});
-  RunMomentsTest(
-      {2, 2, 2},
-      {0, 2},
-      {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f},
-      {3.5f, 5.5f},
-      {4.25, 4.25});
-}
-
 class TransposeGPUTest : public testing::Test {
  protected:
   void SetUp() override {
index 2e59cba..c4fec53 100644 (file)
@@ -742,105 +742,6 @@ TEST_F(RandFixedSumTest, UpperBound) {
       20, 1, 1000, 1000, l.data(), cpu_context_.get());
 }
 
-class MomentsTest : public testing::Test {
- protected:
-  void SetUp() override {
-    cpu_context_ = make_unique<CPUContext>(option_);
-  }
-
-  void RunMomentsTest(
-      const std::vector<int>& X_dims,
-      const std::vector<int>& axes,
-      const std::vector<float>& X_data,
-      const std::vector<float>& mean_data,
-      const std::vector<float>& variance_data) {
-    const int ndim = X_dims.size();
-    std::vector<int> Y_dims = X_dims;
-    for (const int axis : axes) {
-      Y_dims[axis] = 1;
-    }
-    std::vector<int64_t> X_dims_64;
-    std::vector<int64_t> Y_dims_64;
-    std::copy(X_dims.cbegin(), X_dims.cend(), std::back_inserter(X_dims_64));
-    std::copy(Y_dims.cbegin(), Y_dims.cend(), std::back_inserter(Y_dims_64));
-    ReinitializeTensor(&X_, X_dims_64, at::dtype<float>().device(CPU));
-    ReinitializeTensor(&mean_, Y_dims_64, at::dtype<float>().device(CPU));
-    ReinitializeTensor(&variance_, Y_dims_64, at::dtype<float>().device(CPU));
-    ASSERT_EQ(X_data.size(), X_.numel());
-    cpu_context_->CopyFromCPU<float>(
-        X_data.size(), X_data.data(), X_.mutable_data<float>());
-    math::Moments<float, CPUContext>(
-        X_dims.size(),
-        X_dims.data(),
-        axes.size(),
-        axes.data(),
-        X_.data<float>(),
-        mean_.mutable_data<float>(),
-        variance_.mutable_data<float>(),
-        cpu_context_.get());
-    ASSERT_EQ(mean_data.size(), mean_.numel());
-    for (int i = 0; i < mean_data.size(); ++i) {
-      EXPECT_FLOAT_EQ(mean_data[i], mean_.data<float>()[i]);
-    }
-    ASSERT_EQ(variance_data.size(), variance_.numel());
-    for (int i = 0; i < variance_data.size(); ++i) {
-      EXPECT_NEAR(variance_data[i], variance_.data<float>()[i], kEps);
-    }
-  }
-
-  DeviceOption option_;
-  std::unique_ptr<CPUContext> cpu_context_;
-
-  Tensor X_;
-  Tensor mean_;
-  Tensor variance_;
-};
-
-TEST_F(MomentsTest, MomentsFloatTest) {
-  // Test for 1D tensor.
-  RunMomentsTest({3}, {0}, {1.0f, 2.0f, 3.0f}, {2.0f}, {2.0f / 3.0f});
-
-  // Test for 2D Tensor.
-  RunMomentsTest(
-      {2, 3},
-      {1},
-      {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
-      {2.0f, 5.0f},
-      {2.0f / 3.0f, 2.0f / 3.0f});
-  RunMomentsTest(
-      {2, 3},
-      {0},
-      {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
-      {2.5f, 3.5f, 4.5f},
-      {2.25f, 2.25f, 2.25f});
-  RunMomentsTest(
-      {2, 3},
-      {0, 1},
-      {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
-      {3.5f},
-      {35.0f / 12.0f});
-
-  // Test for 3D tensor.
-  RunMomentsTest(
-      {2, 2, 2},
-      {1, 2},
-      {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f},
-      {2.5f, 6.5f},
-      {1.25, 1.25});
-  RunMomentsTest(
-      {2, 2, 2},
-      {0, 1},
-      {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f},
-      {4.0f, 5.0f},
-      {5.0f, 5.0f});
-  RunMomentsTest(
-      {2, 2, 2},
-      {0, 2},
-      {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f},
-      {3.5f, 5.5f},
-      {4.25, 4.25});
-}
-
 class TransposeTest : public testing::Test {
  protected:
   void SetUp() override {
index a5513cd..f4f7da9 100644 (file)
@@ -41,6 +41,15 @@ bool IsIdentityPermutation(const int n, const int* perm) {
   return true;
 }
 
+bool CheckReduceDims(const int ndim, const int* X_dims, const int* Y_dims) {
+  for (int i = 0; i < ndim; ++i) {
+    if (X_dims[i] != Y_dims[i] && Y_dims[i] != 1) {
+      return false;
+    }
+  }
+  return true;
+}
+
 bool IsRowwiseReduce(
     const int ndim,
     const int* A_dims,
@@ -301,6 +310,22 @@ void ComputeTransposeAxesForReduceOp(
   }
 }
 
+void ComputeTransposeAxesForReduceOp(
+    const int ndim,
+    const int* dims,
+    int* axes) {
+  const int d = ndim - std::count(dims, dims + ndim, 1);
+  int p = 0;
+  int q = d;
+  for (int i = 0; i < ndim; ++i) {
+    if (dims[i] == 1) {
+      axes[q++] = i;
+    } else {
+      axes[p++] = i;
+    }
+  }
+}
+
 void ComputeTransposedStrides(
     const int ndim,
     const int* dims,
index bd53eb1..b7dfc6b 100644 (file)
@@ -64,6 +64,9 @@ CAFFE2_API int GetIndexFromDims(const int n, const int* dims, const int* index);
 // Checks if the input permutation is an identity permutation;
 CAFFE2_API bool IsIdentityPermutation(const int n, const int* perm);
 
+CAFFE2_API bool
+CheckReduceDims(const int ndim, const int* X_dims, const int* Y_dims);
+
 CAFFE2_API bool IsRowwiseReduce(
     const int ndim,
     const int* X_dims,
@@ -129,6 +132,9 @@ CAFFE2_API void ComputeTransposeAxesForReduceOp(
     const int* reduce_axes,
     int* transpose_axes);
 
+CAFFE2_API void
+ComputeTransposeAxesForReduceOp(const int ndim, const int* dims, int* axes);
+
 CAFFE2_API void ComputeTransposedStrides(
     const int ndim,
     const int* dims,