Optimize LayerNormOp (#17604)
authorXiaomeng Yang <yangxm@fb.com>
Sat, 9 Mar 2019 01:35:17 +0000 (17:35 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 9 Mar 2019 01:38:14 +0000 (17:38 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17604

Optimize LayerNormOp

i-am-not-moving-c2-to-c10

Reviewed By: houseroad

Differential Revision: D14274175

fbshipit-source-id: a7aa263a1b0eb109682d2be99306e7b2cdcc0faf

caffe2/operators/layer_norm_op.cc
caffe2/operators/layer_norm_op.cu
caffe2/operators/layer_norm_op.h
caffe2/utils/math/reduce.cc

index 2658901..b265ccd 100644 (file)
@@ -1,9 +1,11 @@
 #include "caffe2/operators/layer_norm_op.h"
-#include "caffe2/utils/eigen_utils.h"
-#include "caffe2/core/operator_c10wrapper.h"
+
 #include <ATen/core/dispatch/KernelRegistration.h>
-#include <c10/core/Tensor.h>
 #include <ATen/core/dispatch/OpSchemaRegistration.h>
+#include <c10/core/Tensor.h>
+
+#include "caffe2/core/operator_c10wrapper.h"
+#include "caffe2/utils/eigen_utils.h"
 
 namespace caffe2 {
 
@@ -56,9 +58,11 @@ void LayerNormGradientOp<CPUContext>::ComputeInternalGradients(
     T* ds,
     T* db) {
   ConstEigenArrayMap<T> dY_arr(dY, N, M);
-  EigenVectorArrayMap<T>(ds, M) =
-      (dY_arr * ConstEigenArrayMap<T>(X, N, M)).colwise().sum();
-  EigenVectorArrayMap<T>(db, M) = dY_arr.colwise().sum();
+  ConstEigenArrayMap<T> X_arr(X, N, M);
+  for (int i = 0; i < M; ++i) {
+    ds[i] = (dY_arr.col(i) * X_arr.col(i)).sum();
+    db[i] = dY_arr.col(i).sum();
+  }
 }
 
 template <>
@@ -185,15 +189,12 @@ to the end.)
 } // namespace caffe2
 
 C10_REGISTER_CAFFE2_OPERATOR_CPU(
-  LayerNorm,
-  (std::vector<c10::Argument>{
-    c10::Argument("input"),
-    c10::Argument("axis", c10::IntType::get()),
-    c10::Argument("epsilon", c10::FloatType::get())
-  }), (std::vector<c10::Argument>{
-    c10::Argument("output"),
-    c10::Argument("mean"),
-    c10::Argument("stdev")
-  }),
-  caffe2::LayerNormOp<caffe2::CPUContext>
-)
+    LayerNorm,
+    (std::vector<c10::Argument>{
+        c10::Argument("input"),
+        c10::Argument("axis", c10::IntType::get()),
+        c10::Argument("epsilon", c10::FloatType::get())}),
+    (std::vector<c10::Argument>{c10::Argument("output"),
+                                c10::Argument("mean"),
+                                c10::Argument("stdev")}),
+    caffe2::LayerNormOp<caffe2::CPUContext>)
index c87e267..1dfb783 100644 (file)
@@ -1,9 +1,8 @@
 #include "caffe2/operators/layer_norm_op.h"
 
-#include <cub/cub.cuh>
-
 #include "caffe2/core/context_gpu.h"
 #include "caffe2/utils/math.h"
+#include "caffe2/utils/math/reduce.cuh"
 #include "caffe2/utils/math/utils.h"
 
 namespace caffe2 {
@@ -11,9 +10,6 @@ namespace caffe2 {
 namespace {
 
 template <typename T>
-using BlockReduce = cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS>;
-
-template <typename T>
 __global__ void ComputeStdDevAndFusedParamsCUDAKernel(
     const int N,
     const T epsilon,
@@ -32,17 +28,18 @@ __global__ void ComputeStdDevAndFusedParamsCUDAKernel<float>(
     float* stddev,
     float* scale,
     float* bias) {
-  CUDA_1D_KERNEL_LOOP(i, N) {
+  const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (index < N) {
 #if __CUDA_ARCH__ >= 350
-    const float rstd = rsqrtf(__ldg(var + i) + epsilon);
-    stddev[i] = rstd * (__ldg(var + i) + epsilon);
-    scale[i] = rstd;
-    bias[i] = -rstd * __ldg(mean + i);
+    const float rstd = rsqrtf(__ldg(var + index) + epsilon);
+    stddev[index] = rstd * (__ldg(var + index) + epsilon);
+    scale[index] = rstd;
+    bias[index] = -rstd * __ldg(mean + index);
 #else
-    const float rstd = rsqrtf(var[i] + epsilon);
-    stddev[i] = rstd * (var[i] + epsilon);
-    scale[i] = rstd;
-    bias[i] = -rstd * mean[i];
+    const float rstd = rsqrtf(var[index] + epsilon);
+    stddev[index] = rstd * (var[index] + epsilon);
+    scale[index] = rstd;
+    bias[index] = -rstd * mean[index];
 #endif
   }
 }
@@ -54,23 +51,25 @@ __global__ void LayerNormForwardCUDAKernel(
     const T* X,
     const T* scale,
     const T* bias,
-    T* Y) {
-  for (int i = blockIdx.x; i < M; i += gridDim.x) {
-#if __CUDA_ARCH__ >= 350
-    const float scale_val = __ldg(scale + i);
-    const float bias_val = __ldg(bias + i);
-#else
-    const float scale_val = scale[i];
-    const float bias_val = bias[i];
-#endif
-    for (int j = threadIdx.x; j < N; j += blockDim.x) {
-      const int index = i * N + j;
+    T* Y);
+
+template <>
+__global__ void LayerNormForwardCUDAKernel<float>(
+    const int M,
+    const int N,
+    const float* X,
+    const float* scale,
+    const float* bias,
+    float* Y) {
+  const int size = M * N;
+  const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (index < size) {
+    const int i = index / N;
 #if __CUDA_ARCH__ >= 350
-      Y[index] = __ldg(X + index) * scale_val + bias_val;
+    Y[index] = fmaf(__ldg(X + index), __ldg(scale + i), __ldg(bias + i));
 #else
-      Y[index] = X[index] * scale_val + bias_val;
+    Y[index] = fmaf(X[index], scale[i], bias[i]);
 #endif
-    }
   }
 }
 
@@ -84,26 +83,24 @@ __global__ void ComputeInternalGradientsCUDAKernel(
     T* db) {
   __shared__ typename BlockReduce<T>::TempStorage ds_storage;
   __shared__ typename BlockReduce<T>::TempStorage db_storage;
-  for (int i = blockIdx.x; i < M; i += gridDim.x) {
-    T ds_val = 0;
-    T db_val = 0;
-    for (int j = threadIdx.x; j < N; j += blockDim.x) {
-      const int index = i * N + j;
+  const int i = blockIdx.x;
+  T ds_val = 0;
+  T db_val = 0;
+  for (int j = threadIdx.x; j < N; j += blockDim.x) {
+    const int index = i * N + j;
 #if __CUDA_ARCH__ >= 350
-      ds_val += __ldg(dY + index) * __ldg(X + index);
-      db_val += __ldg(dY + index);
+    ds_val += __ldg(dY + index) * __ldg(X + index);
+    db_val += __ldg(dY + index);
 #else
-      ds_val += dY[index] * X[index];
-      db_val += dY[index];
+    ds_val += dY[index] * X[index];
+    db_val += dY[index];
 #endif
-    }
-    ds_val = BlockReduce<T>(ds_storage).Sum(ds_val);
-    db_val = BlockReduce<T>(db_storage).Sum(db_val);
-    if (threadIdx.x == 0) {
-      ds[i] = ds_val;
-      db[i] = db_val;
-    }
-    __syncthreads();
+  }
+  ds_val = BlockReduce<T>(ds_storage).Sum(ds_val);
+  db_val = BlockReduce<T>(db_storage).Sum(db_val);
+  if (threadIdx.x == 0) {
+    ds[i] = ds_val;
+    db[i] = db_val;
   }
 }
 
@@ -117,23 +114,38 @@ __global__ void ComputeFusedParamsCUDAKernel(
     const T* db,
     T* dY_scale,
     T* X_scale,
-    T* bias) {
-  const T scale = T(1) / static_cast<T>(N);
-  CUDA_1D_KERNEL_LOOP(i, M) {
+    T* bias);
+
+template <>
+__global__ void ComputeFusedParamsCUDAKernel<float>(
+    const int M,
+    const int N,
+    const float* mean,
+    const float* sig,
+    const float* ds,
+    const float* db,
+    float* dY_scale,
+    float* X_scale,
+    float* bias) {
+  const float scale = 1.0f / static_cast<float>(N);
+  const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (index < M) {
 #if __CUDA_ARCH__ >= 350
-    const T rsig = T(1) / __ldg(sig + i);
-    const T X_scale_val = (__ldg(db + i) * __ldg(mean + i) - __ldg(ds + i)) *
-        math::utils::Cube<T>(rsig) * scale;
-    dY_scale[i] = rsig;
-    X_scale[i] = X_scale_val;
-    bias[i] = -X_scale_val * __ldg(mean + i) - __ldg(db + i) * rsig * scale;
+    const float rsig = 1.0f / __ldg(sig + index);
+    const float X_scale_val =
+        fmaf(__ldg(db + index), __ldg(mean + index), -__ldg(ds + index)) *
+        math::utils::Cube<float>(rsig) * scale;
+    dY_scale[index] = rsig;
+    X_scale[index] = X_scale_val;
+    bias[index] = -fmaf(
+        X_scale_val, __ldg(mean + index), __ldg(db + index) * rsig * scale);
 #else
-    const T rsig = T(1) / sig[i];
-    const T X_scale_val =
-        (db[i] * mean[i] - ds[i]) * math::utils::Cube<T>(rsig) * scale;
-    dY_scale[i] = rsig;
-    X_scale[i] = X_scale_val;
-    bias[i] = -X_scale_val * mean[i] - db[i] * rsig * scale;
+    const float rsig = 1.0f / sig[index];
+    const float X_scale_val = fmaf(db[index], mean[index], -ds[index]) *
+        math::utils::Cube<float>(rsig) * scale;
+    dY_scale[index] = rsig;
+    X_scale[index] = X_scale_val;
+    bias[index] = -fmaf(X_scale_val, mean[index], db[index] * rsig * scale);
 #endif
   }
 }
@@ -147,26 +159,31 @@ __global__ void LayerNormBackwardCUDAKenrel(
     const T* X_scale,
     const T* X,
     const T* bias,
-    T* dX) {
-  for (int i = blockIdx.x; i < M; i += gridDim.x) {
-#if __CUDA_ARCH__ >= 350
-    const float dY_scale_val = __ldg(dY_scale + i);
-    const float X_scale_val = __ldg(X_scale + i);
-    const float bias_val = __ldg(bias + i);
-#else
-    const float dY_scale_val = dY_scale[i];
-    const float X_scale_val = X_scale[i];
-    const float bias_val = bias[i];
-#endif
-    for (int j = threadIdx.x; j < N; j += blockDim.x) {
-      const int index = i * N + j;
+    T* dX);
+
+template <>
+__global__ void LayerNormBackwardCUDAKenrel<float>(
+    const int M,
+    const int N,
+    const float* dY_scale,
+    const float* dY,
+    const float* X_scale,
+    const float* X,
+    const float* bias,
+    float* dX) {
+  const int size = M * N;
+  const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (index < size) {
+    const int i = index / N;
 #if __CUDA_ARCH__ >= 350
-      dX[index] = __ldg(dY + index) * dY_scale_val +
-          __ldg(X + index) * X_scale_val + bias_val;
+    dX[index] = fmaf(
+        __ldg(dY + index),
+        __ldg(dY_scale + i),
+        fmaf(__ldg(X + index), __ldg(X_scale + i), __ldg(bias + i)));
 #else
-      dX[index] = dY[index] * dY_scale_val + X[index] * X_scale_val + bias_val;
+    dX[index] =
+        fmaf(dY[index], dY_scale[i], fmaf(X[index], X_scale[i], bias[i]));
 #endif
-    }
   }
 }
 
@@ -183,11 +200,9 @@ void LayerNormOp<CUDAContext>::ComputeStdDevAndFusedParams(
     T* bias,
     float epsilon,
     CUDAContext* context) {
+  const int M = math::DivUp(N, CAFFE_CUDA_NUM_THREADS);
   ComputeStdDevAndFusedParamsCUDAKernel<T>
-      <<<CAFFE_GET_BLOCKS(N),
-         CAFFE_CUDA_NUM_THREADS,
-         0,
-         context->cuda_stream()>>>(
+      <<<M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
           N, static_cast<T>(epsilon), mean, var, stddev, scale, bias);
 }
 
@@ -201,11 +216,10 @@ void LayerNormOp<CUDAContext>::LayerNormForward(
     const T* bias,
     T* Y,
     CUDAContext* context) {
+  const int K = math::DivUp(M * N, CAFFE_CUDA_NUM_THREADS);
   LayerNormForwardCUDAKernel<T>
-      <<<std::min(M, CAFFE_MAXIMUM_NUM_BLOCKS),
-         CAFFE_CUDA_NUM_THREADS,
-         0,
-         context->cuda_stream()>>>(M, N, X, scale, bias, Y);
+      <<<K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+          M, N, X, scale, bias, Y);
 }
 
 REGISTER_CUDA_OPERATOR(LayerNorm, LayerNormOp<CUDAContext>);
@@ -220,10 +234,8 @@ void LayerNormGradientOp<CUDAContext>::ComputeInternalGradients(
     T* ds,
     T* db) {
   ComputeInternalGradientsCUDAKernel<T>
-      <<<std::min(M, CAFFE_MAXIMUM_NUM_BLOCKS),
-         CAFFE_CUDA_NUM_THREADS,
-         0,
-         context_.cuda_stream()>>>(M, N, dY, X, ds, db);
+      <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+          M, N, dY, X, ds, db);
 }
 
 template <>
@@ -238,11 +250,9 @@ void LayerNormGradientOp<CUDAContext>::ComputeFusedParams(
     T* dY_scale,
     T* X_scale,
     T* bias) {
+  const int K = math::DivUp(M, CAFFE_CUDA_NUM_THREADS);
   ComputeFusedParamsCUDAKernel<T>
-      <<<CAFFE_GET_BLOCKS(M),
-         CAFFE_CUDA_NUM_THREADS,
-         0,
-         context_.cuda_stream()>>>(
+      <<<K, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
           M, N, mean, sig, ds, db, dY_scale, X_scale, bias);
 }
 
@@ -257,11 +267,10 @@ void LayerNormGradientOp<CUDAContext>::LayerNormBackward(
     const T* X,
     const T* bias,
     T* dX) {
+  const int K = math::DivUp(M * N, CAFFE_CUDA_NUM_THREADS);
   LayerNormBackwardCUDAKenrel<T>
-      <<<std::min(M, CAFFE_MAXIMUM_NUM_BLOCKS),
-         CAFFE_CUDA_NUM_THREADS,
-         0,
-         context_.cuda_stream()>>>(M, N, dY_scale, dY, X_scale, X, bias, dX);
+      <<<K, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+          M, N, dY_scale, dY, X_scale, X, bias, dX);
 }
 
 REGISTER_CUDA_OPERATOR(LayerNormGradient, LayerNormGradientOp<CUDAContext>);
index 2640f51..a34a082 100644 (file)
@@ -4,11 +4,12 @@
 #include <array>
 #include <vector>
 
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
+
 #include "caffe2/core/context.h"
 #include "caffe2/core/operator.h"
 #include "caffe2/core/types.h"
 #include "caffe2/utils/math.h"
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
 
 C10_DECLARE_CAFFE2_OPERATOR(LayerNorm)
 
@@ -19,14 +20,12 @@ class LayerNormOp final : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
 
-  template<class... Args>
+  template <class... Args>
   explicit LayerNormOp(Args&&... args)
       : Operator<Context>(std::forward<Args>(args)...),
         OP_SINGLE_ARG(int, "axis", axis_, 1),
         OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5f) {}
 
-  ~LayerNormOp() {}
-
   bool RunOnDevice() override {
     return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
   }
@@ -41,19 +40,19 @@ class LayerNormOp final : public Operator<Context> {
     moments_dims.push_back(1);
     auto* mean = Output(1, moments_dims, at::dtype<T>());
     auto* sig = Output(2, moments_dims, at::dtype<T>());
-    runLayerNorm<T>(
-        X, Y, mean, sig, canonical_axis, epsilon_, &scale_, &bias_, &context_);
+    RunLayerNorm<T>(
+        X, canonical_axis, epsilon_, Y, mean, sig, &scale_, &bias_, &context_);
     return true;
   }
 
   template <typename T>
-  static void runLayerNorm(
+  static void RunLayerNorm(
       const Tensor& X,
+      const int canonical_axis,
+      const float epsilon,
       Tensor* Y,
       Tensor* mean,
       Tensor* sig,
-      int canonical_axis,
-      float epsilon,
       Tensor* scale_buffer,
       Tensor* bias_buffer,
       Context* context) {
@@ -63,14 +62,12 @@ class LayerNormOp final : public Operator<Context> {
     Y->ResizeLike(X);
     scale_buffer->Resize(M);
     bias_buffer->Resize(M);
-
     const T* X_data = X.template data<T>();
     T* Y_data = Y->template mutable_data<T>();
     T* mean_data = mean->template mutable_data<T>();
     T* sig_data = sig->template mutable_data<T>();
     T* scale_data = scale_buffer->template mutable_data<T>();
     T* bias_data = bias_buffer->template mutable_data<T>();
-
     const std::array<int, 2> X_dims = {M, N};
     const std::array<int, 2> Y_dims = {M, 1};
     math::Moments<T, Context>(
index bf13634..09824df 100644 (file)
@@ -24,17 +24,20 @@ namespace math {
 
 namespace {
 
-#define DELEGATE_ROWWISE_REDUCE_FUNCTION(Func, EigenFunc)                    \
-  template <typename T>                                                      \
-  void Rowwise##Func(                                                        \
-      const int rows,                                                        \
-      const int cols,                                                        \
-      const T alpha,                                                         \
-      const T* X,                                                            \
-      T* Y,                                                                  \
-      CPUContext* /* context */) {                                           \
-    EigenVectorMap<T>(Y, rows) =                                             \
-        ConstEigenMatrixMap<T>(X, cols, rows).colwise().EigenFunc() * alpha; \
+#define DELEGATE_ROWWISE_REDUCE_FUNCTION(Func, EigenFunc)              \
+  template <typename T>                                                \
+  void Rowwise##Func(                                                  \
+      const int rows,                                                  \
+      const int cols,                                                  \
+      const T alpha,                                                   \
+      const T* X,                                                      \
+      T* Y,                                                            \
+      CPUContext* /* context */) {                                     \
+    EigenVectorMap<T>(Y, rows) = ConstEigenMatrixMap<T>(X, cols, rows) \
+                                     .colwise()                        \
+                                     .EigenFunc()                      \
+                                     .transpose() *                    \
+        alpha;                                                         \
   }
 DELEGATE_ROWWISE_REDUCE_FUNCTION(ReduceMin, minCoeff)
 DELEGATE_ROWWISE_REDUCE_FUNCTION(ReduceMax, maxCoeff)
@@ -184,7 +187,8 @@ void BothEndsReduceSum(
   EigenVectorArrayMap<T> Y_arr(Y, N);
   Y_arr = ConstEigenArrayMap<T>(X, K, N).colwise().sum();
   for (int i = 1; i < M; ++i) {
-    Y_arr += ConstEigenArrayMap<T>(X + i * N * K, K, N).colwise().sum();
+    Y_arr +=
+        ConstEigenArrayMap<T>(X + i * N * K, K, N).colwise().sum().transpose();
   }
   Scale<T, T, CPUContext>(N, alpha, Y, Y, context);
 }
@@ -199,11 +203,12 @@ void BothEndsReduceMean(
     T* Y,
     CPUContext* context) {
   EigenVectorArrayMap<T> Y_arr(Y, N);
-  Y_arr = ConstEigenArrayMap<T>(X, K, N).colwise().mean();
+  Y_arr = ConstEigenArrayMap<T>(X, K, N).colwise().sum();
   for (int i = 1; i < M; ++i) {
-    Y_arr += ConstEigenArrayMap<T>(X + i * N * K, K, N).colwise().mean();
+    Y_arr +=
+        ConstEigenArrayMap<T>(X + i * N * K, K, N).colwise().sum().transpose();
   }
-  Scale<T, T, CPUContext>(N, alpha / static_cast<T>(M), Y, Y, context);
+  Scale<T, T, CPUContext>(N, alpha / static_cast<T>(M * K), Y, Y, context);
 }
 
 template <typename T>
@@ -220,7 +225,8 @@ void BothEndsReduceL1(
   for (int i = 1; i < M; ++i) {
     Y_vec += ConstEigenMatrixMap<T>(X + i * N * K, K, N)
                  .colwise()
-                 .template lpNorm<1>();
+                 .template lpNorm<1>()
+                 .transpose();
   }
   Scale<T, T, CPUContext>(N, alpha, Y, Y, context);
 }
@@ -234,13 +240,18 @@ void BothEndsReduceL2(
     const T* X,
     T* Y,
     CPUContext* /* context */) {
-  EigenVectorMap<T> Y_vec(Y, N);
-  Y_vec = ConstEigenMatrixMap<T>(X, K, N).colwise().squaredNorm();
+  ConstEigenArrayMap<T> X0_arr(X, K, N);
+  EigenVectorArrayMap<T> Y_arr(Y, N);
+  for (int i = 0; i < N; ++i) {
+    Y_arr(i) = X0_arr.col(i).square().sum();
+  }
   for (int i = 1; i < M; ++i) {
-    Y_vec +=
-        ConstEigenMatrixMap<T>(X + i * N * K, K, N).colwise().squaredNorm();
+    ConstEigenArrayMap<T> Xi_arr(X + i * N * K, K, N);
+    for (int j = 0; j < N; ++j) {
+      Y_arr(j) += Xi_arr.col(j).square().sum();
+    }
   }
-  Y_vec = Y_vec.cwiseSqrt() * alpha;
+  Y_arr = Y_arr.sqrt() * alpha;
 }
 
 template <typename T, class Reducer>
@@ -404,10 +415,10 @@ void RowwiseMoments(
     T* mean,
     T* var) {
   ConstEigenArrayMap<T> X_arr(X, cols, rows);
-  EigenVectorArrayMap<T> mean_arr(mean, rows);
-  EigenVectorArrayMap<T> var_arr(var, rows);
-  mean_arr = X_arr.colwise().mean();
-  var_arr = X_arr.square().colwise().mean() - mean_arr.square().transpose();
+  for (int i = 0; i < rows; ++i) {
+    mean[i] = X_arr.col(i).mean();
+    var[i] = X_arr.col(i).square().mean() - mean[i] * mean[i];
+  }
 }
 
 template <typename T>
@@ -420,7 +431,6 @@ void ColwiseMoments(
   ConstEigenArrayMap<T> X_arr(X, cols, rows);
   EigenVectorArrayMap<T> mean_arr(mean, cols);
   EigenVectorArrayMap<T> var_arr(var, cols);
-  // Eigen rowwise reduction is about 10 times slower than this for-loop.
   mean_arr = X_arr.col(0);
   var_arr = X_arr.col(0).square();
   for (int i = 1; i < rows; ++i) {
@@ -440,15 +450,19 @@ void BothEndsMoments(
     const T* X,
     T* mean,
     T* var) {
+  ConstEigenArrayMap<T> X_arr(X, K, M * N);
   EigenVectorArrayMap<T> mean_arr(mean, N);
   EigenVectorArrayMap<T> var_arr(var, N);
-  ConstEigenArrayMap<T> X0_arr(X, K, N);
-  mean_arr = X0_arr.colwise().sum();
-  var_arr = X0_arr.square().colwise().sum();
+  for (int i = 0; i < N; ++i) {
+    mean_arr(i) = X_arr.col(i).sum();
+    var_arr(i) = X_arr.col(i).square().sum();
+  }
   for (int i = 1; i < M; ++i) {
-    ConstEigenArrayMap<T> X_arr(X + i * N * K, K, N);
-    mean_arr += X_arr.colwise().sum();
-    var_arr += X_arr.square().colwise().sum();
+    for (int j = 0; j < N; ++j) {
+      const int c = i * N + j;
+      mean_arr(j) += X_arr.col(c).sum();
+      var_arr(j) += X_arr.col(c).square().sum();
+    }
   }
   const T scale = T(1) / static_cast<T>(M * K);
   mean_arr *= scale;