From 0fd1dc45c037ae806352e08a7b83eeacfad101a0 Mon Sep 17 00:00:00 2001
From: Xiaomeng Yang
Date: Fri, 8 Mar 2019 17:35:17 -0800
Subject: [PATCH] Optimize LayerNormOp (#17604)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17604

Optimize LayerNormOp

i-am-not-moving-c2-to-c10

Reviewed By: houseroad

Differential Revision: D14274175

fbshipit-source-id: a7aa263a1b0eb109682d2be99306e7b2cdcc0faf
---
 caffe2/operators/layer_norm_op.cc |  37 +++----
 caffe2/operators/layer_norm_op.cu | 203 ++++++++++++++++++++------------------
 caffe2/operators/layer_norm_op.h  |  19 ++--
 caffe2/utils/math/reduce.cc       |  78 +++++++++------
 4 files changed, 179 insertions(+), 158 deletions(-)

diff --git a/caffe2/operators/layer_norm_op.cc b/caffe2/operators/layer_norm_op.cc
index 2658901..b265ccd 100644
--- a/caffe2/operators/layer_norm_op.cc
+++ b/caffe2/operators/layer_norm_op.cc
@@ -1,9 +1,11 @@
 #include "caffe2/operators/layer_norm_op.h"
-#include "caffe2/utils/eigen_utils.h"
-#include "caffe2/core/operator_c10wrapper.h"
+
 #include
-#include
 #include
+#include
+
+#include "caffe2/core/operator_c10wrapper.h"
+#include "caffe2/utils/eigen_utils.h"
 
 namespace caffe2 {
 
@@ -56,9 +58,11 @@ void LayerNormGradientOp<CPUContext>::ComputeInternalGradients(
     T* ds,
     T* db) {
   ConstEigenArrayMap<T> dY_arr(dY, N, M);
-  EigenVectorArrayMap<T>(ds, M) =
-      (dY_arr * ConstEigenArrayMap<T>(X, N, M)).colwise().sum();
-  EigenVectorArrayMap<T>(db, M) = dY_arr.colwise().sum();
+  ConstEigenArrayMap<T> X_arr(X, N, M);
+  for (int i = 0; i < M; ++i) {
+    ds[i] = (dY_arr.col(i) * X_arr.col(i)).sum();
+    db[i] = dY_arr.col(i).sum();
+  }
 }
 
 template <>
@@ -185,15 +189,12 @@ to the end.)
 } // namespace caffe2
 
 C10_REGISTER_CAFFE2_OPERATOR_CPU(
-    LayerNorm,
-    (std::vector<c10::Argument>{
-        c10::Argument("input"),
-        c10::Argument("axis", c10::IntType::get()),
-        c10::Argument("epsilon", c10::FloatType::get())
-    }), (std::vector<c10::Argument>{
-        c10::Argument("output"),
-        c10::Argument("mean"),
-        c10::Argument("stdev")
-    }),
-    caffe2::LayerNormOp<caffe2::CPUContext>
-)
+    LayerNorm,
+    (std::vector<c10::Argument>{
+        c10::Argument("input"),
+        c10::Argument("axis", c10::IntType::get()),
+        c10::Argument("epsilon", c10::FloatType::get())}),
+    (std::vector<c10::Argument>{c10::Argument("output"),
+                                c10::Argument("mean"),
+                                c10::Argument("stdev")}),
+    caffe2::LayerNormOp<caffe2::CPUContext>)
diff --git a/caffe2/operators/layer_norm_op.cu b/caffe2/operators/layer_norm_op.cu
index c87e267..1dfb783 100644
--- a/caffe2/operators/layer_norm_op.cu
+++ b/caffe2/operators/layer_norm_op.cu
@@ -1,9 +1,8 @@
 #include "caffe2/operators/layer_norm_op.h"
 
-#include
-
 #include "caffe2/core/context_gpu.h"
 #include "caffe2/utils/math.h"
+#include "caffe2/utils/math/reduce.cuh"
 #include "caffe2/utils/math/utils.h"
 
 namespace caffe2 {
@@ -11,9 +10,6 @@
 namespace {
 
 template <typename T>
-using BlockReduce = cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS>;
-
-template <typename T>
 __global__ void ComputeStdDevAndFusedParamsCUDAKernel(
     const int N,
     const T epsilon,
@@ -32,17 +28,18 @@ __global__ void ComputeStdDevAndFusedParamsCUDAKernel(
     float* stddev,
     float* scale,
     float* bias) {
-  CUDA_1D_KERNEL_LOOP(i, N) {
+  const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (index < N) {
 #if __CUDA_ARCH__ >= 350
-    const float rstd = rsqrtf(__ldg(var + i) + epsilon);
-    stddev[i] = rstd * (__ldg(var + i) + epsilon);
-    scale[i] = rstd;
-    bias[i] = -rstd * __ldg(mean + i);
+    const float rstd = rsqrtf(__ldg(var + index) + epsilon);
+    stddev[index] = rstd * (__ldg(var + index) + epsilon);
+    scale[index] = rstd;
+    bias[index] = -rstd * __ldg(mean + index);
 #else
-    const float rstd = rsqrtf(var[i] + epsilon);
-    stddev[i] = rstd * (var[i] + epsilon);
-    scale[i] = rstd;
-    bias[i] = -rstd * mean[i];
+    const float rstd = rsqrtf(var[index] + epsilon);
+    stddev[index] = rstd * (var[index] + epsilon);
+    scale[index] = rstd;
+    bias[index] = -rstd * mean[index];
 #endif
   }
 }
 
@@ -54,23 +51,25 @@ __global__ void LayerNormForwardCUDAKernel(
     const int M,
     const int N,
     const T* X,
     const T* scale,
     const T* bias,
-    T* Y) {
-  for (int i = blockIdx.x; i < M; i += gridDim.x) {
-#if __CUDA_ARCH__ >= 350
-    const float scale_val = __ldg(scale + i);
-    const float bias_val = __ldg(bias + i);
-#else
-    const float scale_val = scale[i];
-    const float bias_val = bias[i];
-#endif
-    for (int j = threadIdx.x; j < N; j += blockDim.x) {
-      const int index = i * N + j;
+    T* Y);
+
+template <>
+__global__ void LayerNormForwardCUDAKernel(
+    const int M,
+    const int N,
+    const float* X,
+    const float* scale,
+    const float* bias,
+    float* Y) {
+  const int size = M * N;
+  const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (index < size) {
+    const int i = index / N;
 #if __CUDA_ARCH__ >= 350
-      Y[index] = __ldg(X + index) * scale_val + bias_val;
+    Y[index] = fmaf(__ldg(X + index), __ldg(scale + i), __ldg(bias + i));
 #else
-      Y[index] = X[index] * scale_val + bias_val;
+    Y[index] = fmaf(X[index], scale[i], bias[i]);
 #endif
-    }
   }
 }
 
@@ -84,26 +83,24 @@ __global__ void ComputeInternalGradientsCUDAKernel(
     T* db) {
   __shared__ typename BlockReduce<T>::TempStorage ds_storage;
   __shared__ typename BlockReduce<T>::TempStorage db_storage;
-  for (int i = blockIdx.x; i < M; i += gridDim.x) {
-    T ds_val = 0;
-    T db_val = 0;
-    for (int j = threadIdx.x; j < N; j += blockDim.x) {
-      const int index = i * N + j;
+  const int i = blockIdx.x;
+  T ds_val = 0;
+  T db_val = 0;
+  for (int j = threadIdx.x; j < N; j += blockDim.x) {
+    const int index = i * N + j;
 #if __CUDA_ARCH__ >= 350
-      ds_val += __ldg(dY + index) * __ldg(X + index);
-      db_val += __ldg(dY + index);
+    ds_val += __ldg(dY + index) * __ldg(X + index);
+    db_val += __ldg(dY + index);
 #else
-      ds_val += dY[index] * X[index];
-      db_val += dY[index];
+    ds_val += dY[index] * X[index];
+    db_val += dY[index];
 #endif
-    }
-    ds_val = BlockReduce<T>(ds_storage).Sum(ds_val);
-    db_val = BlockReduce<T>(db_storage).Sum(db_val);
-    if (threadIdx.x == 0) {
-      ds[i] = ds_val;
-      db[i] = db_val;
-    }
-    __syncthreads();
+  }
+  ds_val = BlockReduce<T>(ds_storage).Sum(ds_val);
+  db_val = BlockReduce<T>(db_storage).Sum(db_val);
+  if (threadIdx.x == 0) {
+    ds[i] = ds_val;
+    db[i] = db_val;
   }
 }
 
@@ -117,23 +114,38 @@ __global__ void ComputeFusedParamsCUDAKernel(
     const int M,
     const int N,
     const T* mean,
    const T* sig,
     const T* ds,
     const T* db,
     T* dY_scale,
     T* X_scale,
-    T* bias) {
-  const T scale = T(1) / static_cast<T>(N);
-  CUDA_1D_KERNEL_LOOP(i, M) {
+    T* bias);
+
+template <>
+__global__ void ComputeFusedParamsCUDAKernel(
+    const int M,
+    const int N,
+    const float* mean,
+    const float* sig,
+    const float* ds,
+    const float* db,
+    float* dY_scale,
+    float* X_scale,
+    float* bias) {
+  const float scale = 1.0f / static_cast<float>(N);
+  const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (index < M) {
 #if __CUDA_ARCH__ >= 350
-    const T rsig = T(1) / __ldg(sig + i);
-    const T X_scale_val = (__ldg(db + i) * __ldg(mean + i) - __ldg(ds + i)) *
-        math::utils::Cube<T>(rsig) * scale;
-    dY_scale[i] = rsig;
-    X_scale[i] = X_scale_val;
-    bias[i] = -X_scale_val * __ldg(mean + i) - __ldg(db + i) * rsig * scale;
+    const float rsig = 1.0f / __ldg(sig + index);
+    const float X_scale_val =
+        fmaf(__ldg(db + index), __ldg(mean + index), -__ldg(ds + index)) *
+        math::utils::Cube<float>(rsig) * scale;
+    dY_scale[index] = rsig;
+    X_scale[index] = X_scale_val;
+    bias[index] = -fmaf(
+        X_scale_val, __ldg(mean + index), __ldg(db + index) * rsig * scale);
 #else
-    const T rsig = T(1) / sig[i];
-    const T X_scale_val =
-        (db[i] * mean[i] - ds[i]) * math::utils::Cube<T>(rsig) * scale;
-    dY_scale[i] = rsig;
-    X_scale[i] = X_scale_val;
-    bias[i] = -X_scale_val * mean[i] - db[i] * rsig * scale;
+    const float rsig = 1.0f / sig[index];
+    const float X_scale_val = fmaf(db[index], mean[index], -ds[index]) *
+        math::utils::Cube<float>(rsig) * scale;
+    dY_scale[index] = rsig;
+    X_scale[index] = X_scale_val;
+    bias[index] = -fmaf(X_scale_val, mean[index], db[index] * rsig * scale);
 #endif
   }
 }
 
@@ -147,26 +159,31 @@ __global__ void LayerNormBackwardCUDAKenrel(
     const int M,
     const int N,
     const T* dY_scale,
     const T* dY,
     const T* X_scale,
     const T* X,
     const T* bias,
-    T* dX) {
-  for (int i = blockIdx.x; i < M; i += gridDim.x) {
-#if __CUDA_ARCH__ >= 350
-    const float dY_scale_val = __ldg(dY_scale + i);
-    const float X_scale_val = __ldg(X_scale + i);
-    const float bias_val = __ldg(bias + i);
-#else
-    const float dY_scale_val = dY_scale[i];
-    const float X_scale_val = X_scale[i];
-    const float bias_val = bias[i];
-#endif
-    for (int j = threadIdx.x; j < N; j += blockDim.x) {
-      const int index = i * N + j;
+    T* dX);
+
+template <>
+__global__ void LayerNormBackwardCUDAKenrel(
+    const int M,
+    const int N,
+    const float* dY_scale,
+    const float* dY,
+    const float* X_scale,
+    const float* X,
+    const float* bias,
+    float* dX) {
+  const int size = M * N;
+  const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (index < size) {
+    const int i = index / N;
 #if __CUDA_ARCH__ >= 350
-      dX[index] = __ldg(dY + index) * dY_scale_val +
-          __ldg(X + index) * X_scale_val + bias_val;
+    dX[index] = fmaf(
+        __ldg(dY + index),
+        __ldg(dY_scale + i),
+        fmaf(__ldg(X + index), __ldg(X_scale + i), __ldg(bias + i)));
 #else
-      dX[index] = dY[index] * dY_scale_val + X[index] * X_scale_val + bias_val;
+    dX[index] =
+        fmaf(dY[index], dY_scale[i], fmaf(X[index], X_scale[i], bias[i]));
 #endif
-    }
   }
 }
 
@@ -183,11 +200,9 @@ void LayerNormOp<CUDAContext>::ComputeStdDevAndFusedParams(
     T* bias,
     float epsilon,
     CUDAContext* context) {
+  const int M = math::DivUp(N, CAFFE_CUDA_NUM_THREADS);
   ComputeStdDevAndFusedParamsCUDAKernel<T>
-      <<<CAFFE_GET_BLOCKS(N),
-         CAFFE_CUDA_NUM_THREADS,
-         0,
-         context->cuda_stream()>>>(
+      <<<M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
          N, static_cast<T>(epsilon), mean, var, stddev, scale, bias);
 }
 
@@ -201,11 +216,10 @@ void LayerNormOp<CUDAContext>::LayerNormForward(
     const T* bias,
     T* Y,
     CUDAContext* context) {
+  const int K = math::DivUp(M * N, CAFFE_CUDA_NUM_THREADS);
   LayerNormForwardCUDAKernel<T>
-      <<<std::min(M, CAFFE_MAXIMUM_NUM_BLOCKS),
-         CAFFE_CUDA_NUM_THREADS,
-         0,
-         context->cuda_stream()>>>(M, N, X, scale, bias, Y);
+      <<<K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+          M, N, X, scale, bias, Y);
 }
 
 REGISTER_CUDA_OPERATOR(LayerNorm, LayerNormOp<CUDAContext>);
@@ -220,10 +234,8 @@ void LayerNormGradientOp<CUDAContext>::ComputeInternalGradients(
     T* ds,
     T* db) {
   ComputeInternalGradientsCUDAKernel<T>
-      <<<std::min(M, CAFFE_MAXIMUM_NUM_BLOCKS),
-         CAFFE_CUDA_NUM_THREADS,
-         0,
-         context_.cuda_stream()>>>(M, N, dY, X, ds, db);
+      <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+          M, N, dY, X, ds, db);
 }
 
 template <>
@@ -238,11 +250,9 @@ void LayerNormGradientOp<CUDAContext>::ComputeFusedParams(
     T* dY_scale,
     T* X_scale,
     T* bias) {
+  const int K = math::DivUp(M, CAFFE_CUDA_NUM_THREADS);
   ComputeFusedParamsCUDAKernel<T>
-      <<<CAFFE_GET_BLOCKS(M),
-         CAFFE_CUDA_NUM_THREADS,
-         0,
-         context_.cuda_stream()>>>(
+      <<<K, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
          M, N, mean, sig, ds, db, dY_scale, X_scale, bias);
 }
 
@@ -257,11 +267,10 @@ void LayerNormGradientOp<CUDAContext>::LayerNormBackward(
     const T* X,
     const T* bias,
     T* dX) {
+  const int K = math::DivUp(M * N, CAFFE_CUDA_NUM_THREADS);
   LayerNormBackwardCUDAKenrel<T>
-      <<<std::min(M, CAFFE_MAXIMUM_NUM_BLOCKS),
-         CAFFE_CUDA_NUM_THREADS,
-         0,
-         context_.cuda_stream()>>>(M, N, dY_scale, dY, X_scale, X, bias, dX);
+      <<<K, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+          M, N, dY_scale, dY, X_scale, X, bias, dX);
 }
 
 REGISTER_CUDA_OPERATOR(LayerNormGradient, LayerNormGradientOp<CUDAContext>);
diff --git a/caffe2/operators/layer_norm_op.h b/caffe2/operators/layer_norm_op.h
index 2640f51..a34a082 100644
--- a/caffe2/operators/layer_norm_op.h
+++ b/caffe2/operators/layer_norm_op.h
@@ -4,11 +4,12 @@
 #include
 #include
 
+#include
+
 #include "caffe2/core/context.h"
 #include "caffe2/core/operator.h"
 #include "caffe2/core/types.h"
 #include "caffe2/utils/math.h"
-#include
 
 C10_DECLARE_CAFFE2_OPERATOR(LayerNorm)
 
@@ -19,14 +20,12 @@ class LayerNormOp final : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
 
-  template
+  template
   explicit LayerNormOp(Args&&... args)
       : Operator<Context>(std::forward<Args>(args)...),
         OP_SINGLE_ARG(int, "axis", axis_, 1),
         OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5f) {}
 
-  ~LayerNormOp() {}
-
   bool RunOnDevice() override {
     return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
   }
 
@@ -41,19 +40,19 @@ class LayerNormOp final : public Operator<Context> {
     moments_dims.push_back(1);
     auto* mean = Output(1, moments_dims, at::dtype<T>());
     auto* sig = Output(2, moments_dims, at::dtype<T>());
-    runLayerNorm<T>(
-        X, Y, mean, sig, canonical_axis, epsilon_, &scale_, &bias_, &context_);
+    RunLayerNorm<T>(
+        X, canonical_axis, epsilon_, Y, mean, sig, &scale_, &bias_, &context_);
     return true;
   }
 
   template <typename T>
-  static void runLayerNorm(
+  static void RunLayerNorm(
       const Tensor& X,
+      const int canonical_axis,
+      const float epsilon,
       Tensor* Y,
      Tensor* mean,
      Tensor* sig,
-      int canonical_axis,
-      float epsilon,
      Tensor* scale_buffer,
      Tensor* bias_buffer,
      Context* context) {
@@ -63,14 +62,12 @@ class LayerNormOp final : public Operator<Context> {
     Y->ResizeLike(X);
     scale_buffer->Resize(M);
     bias_buffer->Resize(M);
-
     const T* X_data = X.template data<T>();
     T* Y_data = Y->template mutable_data<T>();
     T* mean_data = mean->template mutable_data<T>();
     T* sig_data = sig->template mutable_data<T>();
     T* scale_data = scale_buffer->template mutable_data<T>();
     T* bias_data = bias_buffer->template mutable_data<T>();
-
     const std::array<int, 2> X_dims = {M, N};
     const std::array<int, 2> Y_dims = {M, 1};
     math::Moments<T, Context>(
diff --git a/caffe2/utils/math/reduce.cc b/caffe2/utils/math/reduce.cc
index bf13634..09824df 100644
--- a/caffe2/utils/math/reduce.cc
+++ b/caffe2/utils/math/reduce.cc
@@ -24,17 +24,20 @@ namespace math {
 
 namespace {
 
-#define DELEGATE_ROWWISE_REDUCE_FUNCTION(Func, EigenFunc)                    \
-  template <typename T>                                                      \
-  void Rowwise##Func(                                                        \
-      const int rows,                                                        \
-      const int cols,                                                        \
-      const T alpha,                                                         \
-      const T* X,                                                            \
-      T* Y,                                                                  \
-      CPUContext* /* context */) {                                           \
-    EigenVectorMap<T>(Y, rows) =                                             \
-        ConstEigenMatrixMap<T>(X, cols, rows).colwise().EigenFunc() * alpha; \
+#define DELEGATE_ROWWISE_REDUCE_FUNCTION(Func, EigenFunc)              \
+  template <typename T>                                                \
+  void Rowwise##Func(                                                  \
+      const int rows,                                                  \
+      const int cols,                                                  \
+      const T alpha,                                                   \
+      const T* X,                                                      \
+      T* Y,                                                            \
+      CPUContext* /* context */) {                                     \
+    EigenVectorMap<T>(Y, rows) = ConstEigenMatrixMap<T>(X, cols, rows) \
+                                     .colwise()                        \
+                                     .EigenFunc()                      \
+                                     .transpose() *                    \
+        alpha;                                                         \
   }
 DELEGATE_ROWWISE_REDUCE_FUNCTION(ReduceMin, minCoeff)
 DELEGATE_ROWWISE_REDUCE_FUNCTION(ReduceMax, maxCoeff)
@@ -184,7 +187,8 @@ void BothEndsReduceSum(
   EigenVectorArrayMap<T> Y_arr(Y, N);
   Y_arr = ConstEigenArrayMap<T>(X, K, N).colwise().sum();
   for (int i = 1; i < M; ++i) {
-    Y_arr += ConstEigenArrayMap<T>(X + i * N * K, K, N).colwise().sum();
+    Y_arr +=
+        ConstEigenArrayMap<T>(X + i * N * K, K, N).colwise().sum().transpose();
   }
   Scale<T, T, CPUContext>(N, alpha, Y, Y, context);
 }
@@ -199,11 +203,12 @@ void BothEndsReduceMean(
     T* Y,
     CPUContext* context) {
   EigenVectorArrayMap<T> Y_arr(Y, N);
-  Y_arr = ConstEigenArrayMap<T>(X, K, N).colwise().mean();
+  Y_arr = ConstEigenArrayMap<T>(X, K, N).colwise().sum();
   for (int i = 1; i < M; ++i) {
-    Y_arr += ConstEigenArrayMap<T>(X + i * N * K, K, N).colwise().mean();
+    Y_arr +=
+        ConstEigenArrayMap<T>(X + i * N * K, K, N).colwise().sum().transpose();
   }
-  Scale<T, T, CPUContext>(N, alpha / static_cast<T>(M), Y, Y, context);
+  Scale<T, T, CPUContext>(N, alpha / static_cast<T>(M * K), Y, Y, context);
 }
 
 template <typename T>
@@ -220,7 +225,8 @@ void BothEndsReduceL1(
   for (int i = 1; i < M; ++i) {
     Y_vec += ConstEigenMatrixMap<T>(X + i * N * K, K, N)
                  .colwise()
-                 .template lpNorm<1>();
+                 .template lpNorm<1>()
+                 .transpose();
   }
   Scale<T, T, CPUContext>(N, alpha, Y, Y, context);
 }
@@ -234,13 +240,18 @@ void BothEndsReduceL2(
     const T* X,
     T* Y,
     CPUContext* /* context */) {
-  EigenVectorMap<T> Y_vec(Y, N);
-  Y_vec = ConstEigenMatrixMap<T>(X, K, N).colwise().squaredNorm();
+  ConstEigenArrayMap<T> X0_arr(X, K, N);
+  EigenVectorArrayMap<T> Y_arr(Y, N);
+  for (int i = 0; i < N; ++i) {
+    Y_arr(i) = X0_arr.col(i).square().sum();
+  }
   for (int i = 1; i < M; ++i) {
-    Y_vec +=
-        ConstEigenMatrixMap<T>(X + i * N * K, K, N).colwise().squaredNorm();
+    ConstEigenArrayMap<T> Xi_arr(X + i * N * K, K, N);
+    for (int j = 0; j < N; ++j) {
+      Y_arr(j) += Xi_arr.col(j).square().sum();
+    }
   }
-  Y_vec = Y_vec.cwiseSqrt() * alpha;
+  Y_arr = Y_arr.sqrt() * alpha;
 }
 
 template <typename T>
@@ -404,10 +415,10 @@ void RowwiseMoments(
     T* mean,
     T* var) {
   ConstEigenArrayMap<T> X_arr(X, cols, rows);
-  EigenVectorArrayMap<T> mean_arr(mean, rows);
-  EigenVectorArrayMap<T> var_arr(var, rows);
-  mean_arr = X_arr.colwise().mean();
-  var_arr = X_arr.square().colwise().mean() - mean_arr.square().transpose();
+  for (int i = 0; i < rows; ++i) {
+    mean[i] = X_arr.col(i).mean();
+    var[i] = X_arr.col(i).square().mean() - mean[i] * mean[i];
+  }
 }
 
 template <typename T>
@@ -420,7 +431,6 @@ void ColwiseMoments(
   ConstEigenArrayMap<T> X_arr(X, cols, rows);
   EigenVectorArrayMap<T> mean_arr(mean, cols);
   EigenVectorArrayMap<T> var_arr(var, cols);
-  // Eigen rowwise reduction is about 10 times slower than this for-loop.
   mean_arr = X_arr.col(0);
   var_arr = X_arr.col(0).square();
   for (int i = 1; i < rows; ++i) {
@@ -440,15 +450,19 @@ void BothEndsMoments(
     const T* X,
     T* mean,
     T* var) {
+  ConstEigenArrayMap<T> X_arr(X, K, M * N);
   EigenVectorArrayMap<T> mean_arr(mean, N);
   EigenVectorArrayMap<T> var_arr(var, N);
-  ConstEigenArrayMap<T> X0_arr(X, K, N);
-  mean_arr = X0_arr.colwise().sum();
-  var_arr = X0_arr.square().colwise().sum();
+  for (int i = 0; i < N; ++i) {
+    mean_arr(i) = X_arr.col(i).sum();
+    var_arr(i) = X_arr.col(i).square().sum();
+  }
   for (int i = 1; i < M; ++i) {
-    ConstEigenArrayMap<T> X_arr(X + i * N * K, K, N);
-    mean_arr += X_arr.colwise().sum();
-    var_arr += X_arr.square().colwise().sum();
+    for (int j = 0; j < N; ++j) {
+      const int c = i * N + j;
+      mean_arr(j) += X_arr.col(c).sum();
+      var_arr(j) += X_arr.col(c).square().sum();
+    }
   }
   const T scale = T(1) / static_cast<T>(M * K);
   mean_arr *= scale;
-- 
2.7.4