#include <cub/block/block_reduce.cuh>
#include "caffe2/core/context_gpu.h"
-#include "caffe2/utils/math_utils.h"
+#include "caffe2/utils/math.h"
namespace caffe2 {
template <typename T>
using BlockReduce = cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS>;
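+// BlockReduce2D: CUB block-wide reduction over a 2D thread block of
+// kBlockDimX x kBlockDimY threads, using warp-level reductions.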
+template <typename T, int kBlockDimX, int kBlockDimY>
+using BlockReduce2D = cub::
+ BlockReduce<T, kBlockDimX, cub::BLOCK_REDUCE_WARP_REDUCTIONS, kBlockDimY>;
+
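+// Fuses the per-group statistics with the affine parameters:
+//   scale[n, g * K + k] = gamma[g * K + k] * rsig[n, g]
+//   bias[n, g * K + k]  = beta[g * K + k] - scale[n, g * K + k] * mu[n, g]
+// Grid is (N, G); each block's threads stride over the K channels of group g.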
template <typename T>
__global__ void ComputeFusedParamsCUDAKernel(
- const int N,
const int G,
- const int D,
+ const int K,
const T* mu,
const T* rsig,
const T* gamma,
const T* beta,
T* scale,
T* bias) {
- const int outer_size = N * G;
- for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
- const int g = i % G;
+ const int n = blockIdx.x;
+ const int g = blockIdx.y;
+ const int i_mu = n * G + g;
+ for (int i = threadIdx.x; i < K; i += blockDim.x) {
+ const int index = i_mu * K + i;
+ const int i_gamma = g * K + i;
#if __CUDA_ARCH__ >= 350
- const T mu_val = __ldg(mu + i);
- const T rsig_val = __ldg(rsig + i);
+ const T scale_val = __ldg(gamma + i_gamma) * __ldg(rsig + i_mu);
+ scale[index] = scale_val;
+ bias[index] = __ldg(beta + i_gamma) - scale_val * __ldg(mu + i_mu);
#else
- const T mu_val = mu[i];
- const T rsig_val = rsig[i];
+ const T scale_val = gamma[i_gamma] * rsig[i_mu];
+ scale[index] = scale_val;
+ bias[index] = beta[i_gamma] - scale_val * mu[i_mu];
#endif
- for (int j = threadIdx.x; j < D; j += blockDim.x) {
- const int index = i * D + j;
- const int i_gamma = g * D + j;
-#if __CUDA_ARCH__ >= 350
- const T scale_val = __ldg(gamma + i_gamma) * rsig_val;
- scale[index] = scale_val;
- bias[index] = __ldg(beta + i_gamma) - scale_val * mu_val;
-#else
- const T scale_val = gamma[i_gamma] * rsig_val;
- scale[index] = scale_val;
- bias[index] = beta[i_gamma] - scale_val * mu_val;
-#endif
- }
}
}
-template <typename T, StorageOrder kOrder>
-__global__ void GroupNormForwardCUDAKernel(
- const int N,
- const int C,
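+// NCHW forward: Y = X * scale + bias with per-(n, c) scale and bias. HxW is
+// covered by W = ceil(HxW / CAFFE_CUDA_NUM_THREADS) tiles, so the grid has
+// N * C * W blocks; block i handles tile (i % W) of channel (i / W).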
+template <typename T>
+__global__ void GroupNormForwardNCHWCUDAKernel(
+    const int W,
const int HxW,
const T* X,
    const T* scale,
    const T* bias,
T* Y);
template <>
-__global__ void GroupNormForwardCUDAKernel<float, StorageOrder::NCHW>(
- const int N,
- const int C,
+__global__ void GroupNormForwardNCHWCUDAKernel<float>(
+ const int W,
const int HxW,
const float* X,
const float* scale,
const float* bias,
float* Y) {
- const int outer_size = N * C;
- for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
-#if __CUDA_ARCH__ >= 350
- const float scale_val = __ldg(scale + i);
- const float bias_val = __ldg(bias + i);
-#else
- const float scale_val = scale[i];
- const float bias_val = bias[i];
-#endif
- for (int j = threadIdx.x; j < HxW; j += blockDim.x) {
- const int index = i * HxW + j;
+ const int nc = blockIdx.x / W;
+ const int hw = blockIdx.x % W * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+ if (hw < HxW) {
+ const int index = nc * HxW + hw;
#if __CUDA_ARCH__ >= 350
- Y[index] = __ldg(X + index) * scale_val + bias_val;
+ Y[index] = fmaf(__ldg(X + index), __ldg(scale + nc), __ldg(bias + nc));
#else
- Y[index] = X[index] * scale_val + bias_val;
+ Y[index] = fmaf(X[index], scale[nc], bias[nc]);
#endif
- }
}
}
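+// NHWC forward: one block per (n, h, w) position; threads stride over the C
+// channels, reading the per-(n, c) scale and bias.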
+template <typename T>
+__global__ void GroupNormForwardNHWCCUDAKernel(
+ const int C,
+ const int HxW,
+ const T* X,
+ const T* scale,
+ const T* bias,
+ T* Y);
+
template <>
-__global__ void GroupNormForwardCUDAKernel<float, StorageOrder::NHWC>(
- const int N,
+__global__ void GroupNormForwardNHWCCUDAKernel<float>(
const int C,
const int HxW,
const float* X,
const float* scale,
const float* bias,
float* Y) {
- const int outer_size = N * HxW;
- for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
- const int n = i / HxW;
- for (int j = threadIdx.x; j < C; j += blockDim.x) {
- const int index = i * C + j;
- const int i_scale = n * C + j;
+ const int n = blockIdx.x / HxW;
+ for (int c = threadIdx.x; c < C; c += blockDim.x) {
+ const int index = blockIdx.x * C + c;
+ const int nc = n * C + c;
#if __CUDA_ARCH__ >= 350
- Y[index] =
- __ldg(X + index) * __ldg(scale + i_scale) + __ldg(bias + i_scale);
+ Y[index] = fmaf(__ldg(X + index), __ldg(scale + nc), __ldg(bias + nc));
#else
- Y[index] = X[index] * scale[i_scale] + bias[i_scale];
+ Y[index] = fmaf(X[index], scale[nc], bias[nc]);
#endif
- }
}
}
-template <typename T, StorageOrder kOrder>
-__global__ void ComputeInternalGradientsCUDAKernel(
- const int N,
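+// Per-(n, g) partial sums used by the backward pass:
+//   ds[n, g] = sum_{k, h, w} gamma[g * K + k] * dY * X
+//   db[n, g] = sum_{k, h, w} gamma[g * K + k] * dY
+// NCHW layout: one block per (n, g); a 1D block reduction over K * HxW.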
+template <typename T>
+__global__ void ComputeInternalGradientsNCHWCUDAKernel(
const int G,
- const int D,
+ const int K,
const int HxW,
const T* dY,
const T* X,
const T* gamma,
T* ds,
T* db) {
- const int outer_size = N * G;
- const int inner_size = D * HxW;
__shared__ typename BlockReduce<T>::TempStorage ds_storage;
__shared__ typename BlockReduce<T>::TempStorage db_storage;
- for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
- T ds_val = 0;
- T db_val = 0;
- for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
- const int i_gamma = i % G * D + j / HxW;
- const int index = kOrder == StorageOrder::NCHW
- ? i * inner_size + j
- : (i / G * HxW + j % HxW) * G * D + i_gamma;
+ const int inner_size = K * HxW;
+ const int n = blockIdx.x;
+ const int g = blockIdx.y;
+ const int ng = n * G + g;
+ T ds_val = 0;
+ T db_val = 0;
+ for (int i = threadIdx.x; i < inner_size; i += blockDim.x) {
+ const int c = g * K + i / HxW;
+ const int index = ng * inner_size + i;
#if __CUDA_ARCH__ >= 350
- ds_val += __ldg(gamma + i_gamma) * __ldg(dY + index) * __ldg(X + index);
- db_val += __ldg(gamma + i_gamma) * __ldg(dY + index);
+ ds_val += __ldg(gamma + c) * __ldg(dY + index) * __ldg(X + index);
+ db_val += __ldg(gamma + c) * __ldg(dY + index);
#else
- ds_val += gamma[i_gamma] * dY[index] * X[index];
- db_val += gamma[i_gamma] * dY[index];
+ ds_val += gamma[c] * dY[index] * X[index];
+ db_val += gamma[c] * dY[index];
+#endif
+ }
+ ds_val = BlockReduce<T>(ds_storage).Sum(ds_val);
+ db_val = BlockReduce<T>(db_storage).Sum(db_val);
+ if (threadIdx.x == 0) {
+ ds[ng] = ds_val;
+ db[ng] = db_val;
+ }
+}
+
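+// NHWC variant of the ds/db reduction: a 2D block in which threadIdx.x
+// strides over HxW and threadIdx.y strides over the K channels of group g.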
+template <typename T, int kBlockDimX, int kBlockDimY>
+__global__ void ComputeInternalGradientsNHWCCUDAKernel(
+ const int G,
+ const int K,
+ const int HxW,
+ const T* dY,
+ const T* X,
+ const T* gamma,
+ T* ds,
+ T* db) {
+  __shared__
+      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage ds_storage;
+  __shared__
+      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage db_storage;
+ const int C = G * K;
+ const int n = blockIdx.x;
+ const int g = blockIdx.y;
+ const int ng = n * G + g;
+ T ds_val = 0;
+ T db_val = 0;
+ for (int i = threadIdx.x; i < HxW; i += blockDim.x) {
+ for (int j = threadIdx.y; j < K; j += blockDim.y) {
+ const int c = g * K + j;
+ const int index = (n * HxW + i) * C + c;
+#if __CUDA_ARCH__ >= 350
+ ds_val += __ldg(gamma + c) * __ldg(dY + index) * __ldg(X + index);
+ db_val += __ldg(gamma + c) * __ldg(dY + index);
+#else
+ ds_val += gamma[c] * dY[index] * X[index];
+ db_val += gamma[c] * dY[index];
#endif
}
- ds_val = BlockReduce<T>(ds_storage).Reduce(ds_val, cub::Sum());
- db_val = BlockReduce<T>(db_storage).Reduce(db_val, cub::Sum());
- if (threadIdx.x == 0) {
- ds[i] = ds_val;
- db[i] = db_val;
- }
- __syncthreads();
+ }
+  ds_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(ds_storage).Sum(ds_val);
+  db_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(db_storage).Sum(db_val);
+ if (threadIdx.x == 0 && threadIdx.y == 0) {
+ ds[ng] = ds_val;
+ db[ng] = db_val;
}
}
// db/dX = -u * drsig/dX - rsig * dmu/dX
// drsig/dX = -rsig^3 * (X - mu) / n
// dmu/dX = 1 / n
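+// Combining the above with ds = Sum(gamma * dY * X), db = Sum(gamma * dY)
+// and n = K * HxW gives the form computed below:
+//   dX = gamma * dY * rsig + (u - v) / n, where
+//   u = (db * mu - ds) * (X - mu) * rsig^3 and v = db * rsig.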
-template <typename T, StorageOrder kOrder>
-__global__ void GroupNormBackwardCUDAKernel(
- const int size,
+template <typename T>
+__global__ void GroupNormBackwardNCHWCUDAKernel(
+ const int G,
+ const int K,
+ const int W,
+ const int HxW,
+ const T* dY,
+ const T* X,
+ const T* mu,
+ const T* rsig,
+ const T* gamma,
+ const T* ds,
+ const T* db,
+ T* dX) {
+ const int C = G * K;
+ const T denom = T(1) / static_cast<T>(K * HxW);
+ const int nc = blockIdx.x / W;
+ const int n = nc / C;
+ const int c = nc % C;
+ const int g = c / K;
+ const int ng = n * G + g;
+ const int hw = blockIdx.x % W * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+ const int index = nc * HxW + hw;
+ if (hw < HxW) {
+#if __CUDA_ARCH__ >= 350
+ const T u = (__ldg(db + ng) * __ldg(mu + ng) - __ldg(ds + ng)) *
+ (__ldg(X + index) - __ldg(mu + ng)) *
+ math::utils::Cube<T>(__ldg(rsig + ng));
+ const T v = __ldg(db + ng) * __ldg(rsig + ng);
+ dX[index] = __ldg(gamma + c) * __ldg(dY + index) * __ldg(rsig + ng) +
+ (u - v) * denom;
+#else
+ const T u = (db[ng] * mu[ng] - ds[ng]) * (X[index] - mu[ng]) *
+ math::utils::Cube<T>(rsig[ng]);
+ const T v = db[ng] * rsig[ng];
+ dX[index] = gamma[c] * dY[index] * rsig[ng] + (u - v) * denom;
+#endif
+ }
+}
+
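+// NHWC backward for dX: grid is (N * HxW, G); block (x, g) handles spatial
+// position x of group g, with threads striding over the K channels.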
+template <typename T>
+__global__ void GroupNormBackwardNHWCCUDAKernel(
const int G,
- const int D,
+ const int K,
const int HxW,
const T* dY,
    const T* X,
    const T* mu,
    const T* rsig,
    const T* gamma,
const T* ds,
const T* db,
T* dX) {
- const int C = G * D;
- const T denom = T(1) / static_cast<T>(D * HxW);
- CUDA_1D_KERNEL_LOOP(i, size) {
- const int i_mu = kOrder == StorageOrder::NCHW
- ? i / (D * HxW)
- : i / (C * HxW) * G + (i / D % G);
- const int i_gamma = kOrder == StorageOrder::NCHW ? (i / HxW) % C : i % C;
+ const int C = G * K;
+ const T denom = T(1) / static_cast<T>(K * HxW);
+ const int x = blockIdx.x;
+ const int g = blockIdx.y;
+ const int n = x / HxW;
+ const int ng = n * G + g;
+ for (int i = threadIdx.x; i < K; i += blockDim.x) {
+ const int c = g * K + i;
+ const int index = x * C + c;
#if __CUDA_ARCH__ >= 350
- const T u = (__ldg(db + i_mu) * __ldg(mu + i_mu) - __ldg(ds + i_mu)) *
- (__ldg(X + i) - __ldg(mu + i_mu)) *
- math::utils::Cube<T>(__ldg(rsig + i_mu));
- const T v = __ldg(db + i_mu) * __ldg(rsig + i_mu);
- dX[i] = __ldg(gamma + i_gamma) * __ldg(dY + i) * __ldg(rsig + i_mu) +
+ const T u = (__ldg(db + ng) * __ldg(mu + ng) - __ldg(ds + ng)) *
+ (__ldg(X + index) - __ldg(mu + ng)) *
+ math::utils::Cube<T>(__ldg(rsig + ng));
+ const T v = __ldg(db + ng) * __ldg(rsig + ng);
+ dX[index] = __ldg(gamma + c) * __ldg(dY + index) * __ldg(rsig + ng) +
(u - v) * denom;
#else
- const T u = (db[i_mu] * mu[i_mu] - ds[i_mu]) * (X[i] - mu[i_mu]) *
- math::utils::Cube<T>(rsig[i_mu]);
- const T v = db[i_mu] * rsig[i_mu];
- dX[i] = gamma[i_gamma] * dY[i] * rsig[i_mu] + (u - v) * denom;
+ const T u = (db[ng] * mu[ng] - ds[ng]) * (X[index] - mu[ng]) *
+ math::utils::Cube<T>(rsig[ng]);
+ const T v = db[ng] * rsig[ng];
+ dX[index] = gamma[c] * dY[index] * rsig[ng] + (u - v) * denom;
#endif
}
}
-template <typename T, StorageOrder kOrder>
-__global__ void GammaBetaBackwardCUDAKernel(
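+// Per-channel parameter gradients, reduced over N and HxW:
+//   dgamma[c] = sum_{n, h, w} dY * (X - mu[n, g]) * rsig[n, g]
+//   dbeta[c]  = sum_{n, h, w} dY
+// One block per channel c; threadIdx.x strides over N, threadIdx.y over HxW.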
+template <typename T, int kBlockDimX, int kBlockDimY>
+__global__ void GammaBetaBackwardNCHWCUDAKernel(
const int N,
const int G,
- const int D,
+ const int K,
const int HxW,
const T* dY,
    const T* X,
    const T* mu,
const T* rsig,
T* dgamma,
T* dbeta) {
- const int outer_size = G * D;
- const int inner_size = N * HxW;
- __shared__ typename BlockReduce<T>::TempStorage dg_storage;
- __shared__ typename BlockReduce<T>::TempStorage db_storage;
- for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
- T dg_val = 0;
- T db_val = 0;
- for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
- const int n = j / HxW;
- const int index = kOrder == StorageOrder::NCHW
- ? (n * outer_size + i) * HxW + j % HxW
- : j * outer_size + i;
- const int i_mu = n * G + i / D;
+ __shared__
+ typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage dg_storage;
+ __shared__
+ typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage db_storage;
+ const int C = G * K;
+ const int c = blockIdx.x;
+ const int g = c / K;
+ T dg_val = 0;
+ T db_val = 0;
+ for (int i = threadIdx.x; i < N; i += blockDim.x) {
+ for (int j = threadIdx.y; j < HxW; j += blockDim.y) {
+ const int index = (i * C + c) * HxW + j;
+ const int ng = i * G + g;
#if __CUDA_ARCH__ >= 350
- dg_val += __ldg(dY + index) * (__ldg(X + index) - __ldg(mu + i_mu)) *
- __ldg(rsig + i_mu);
+ dg_val += __ldg(dY + index) * (__ldg(X + index) - __ldg(mu + ng)) *
+ __ldg(rsig + ng);
db_val += __ldg(dY + index);
#else
- dg_val += dY[index] * (X[index] - mu[i_mu]) * rsig[i_mu];
+ dg_val += dY[index] * (X[index] - mu[ng]) * rsig[ng];
db_val += dY[index];
#endif
}
- dg_val = BlockReduce<T>(dg_storage).Reduce(dg_val, cub::Sum());
- db_val = BlockReduce<T>(db_storage).Reduce(db_val, cub::Sum());
- if (threadIdx.x == 0) {
- dgamma[i] = dg_val;
- dbeta[i] = db_val;
+ }
+ dg_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(dg_storage).Sum(dg_val);
+ db_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(db_storage).Sum(db_val);
+ if (threadIdx.x == 0 && threadIdx.y == 0) {
+ dgamma[c] = dg_val;
+ dbeta[c] = db_val;
+ }
+}
+
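+// NHWC layout of the same per-channel dgamma/dbeta reduction.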
+template <typename T, int kBlockDimX, int kBlockDimY>
+__global__ void GammaBetaBackwardNHWCCUDAKernel(
+ const int N,
+ const int G,
+ const int K,
+ const int HxW,
+ const T* dY,
+ const T* X,
+ const T* mu,
+ const T* rsig,
+ T* dgamma,
+ T* dbeta) {
+ __shared__
+ typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage dg_storage;
+ __shared__
+ typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage db_storage;
+ const int C = G * K;
+ const int c = blockIdx.x;
+ const int g = c / K;
+ T dg_val = 0;
+ T db_val = 0;
+ for (int i = threadIdx.x; i < N; i += blockDim.x) {
+ for (int j = threadIdx.y; j < HxW; j += blockDim.y) {
+ const int index = (i * HxW + j) * C + c;
+ const int ng = i * G + g;
+#if __CUDA_ARCH__ >= 350
+ dg_val += __ldg(dY + index) * (__ldg(X + index) - __ldg(mu + ng)) *
+ __ldg(rsig + ng);
+ db_val += __ldg(dY + index);
+#else
+ dg_val += dY[index] * (X[index] - mu[ng]) * rsig[ng];
+ db_val += dY[index];
+#endif
}
- __syncthreads();
+ }
+ dg_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(dg_storage).Sum(dg_val);
+ db_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(db_storage).Sum(db_val);
+ if (threadIdx.x == 0 && threadIdx.y == 0) {
+ dgamma[c] = dg_val;
+ dbeta[c] = db_val;
}
}
void GroupNormOp<float, CUDAContext>::ComputeFusedParams(
const int N,
const int G,
- const int D,
+ const int K,
const float* mu,
const float* rsig,
    const float* gamma,
    const float* beta,
float* scale,
float* bias) {
ComputeFusedParamsCUDAKernel<float>
- <<<std::min(N * G, CAFFE_MAXIMUM_NUM_BLOCKS),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(N, G, D, mu, rsig, gamma, beta, scale, bias);
+ <<<dim3(N, G), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+ G, K, mu, rsig, gamma, beta, scale, bias);
}
template <>
const float* scale,
const float* bias,
float* Y) {
- GroupNormForwardCUDAKernel<float, StorageOrder::NCHW>
- <<<std::min(N * C, CAFFE_MAXIMUM_NUM_BLOCKS),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(N, C, HxW, X, scale, bias, Y);
+ const int W = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS);
+ GroupNormForwardNCHWCUDAKernel<float>
+ <<<N * C * W, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+ W, HxW, X, scale, bias, Y);
}
template <>
const float* scale,
const float* bias,
float* Y) {
- GroupNormForwardCUDAKernel<float, StorageOrder::NHWC>
- <<<std::min(N * HxW, CAFFE_MAXIMUM_NUM_BLOCKS),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(N, C, HxW, X, scale, bias, Y);
+ GroupNormForwardNHWCCUDAKernel<float>
+ <<<N * HxW, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+ C, HxW, X, scale, bias, Y);
}
// Math:
bool GroupNormGradientOp<float, CUDAContext>::RunOnDeviceImpl(
const int N,
const int G,
- const int D,
+ const int K,
const int HxW,
const float* dY_data,
const float* X_data,
float* dX_data,
float* dgamma_data,
float* dbeta_data) {
- const int size = N * G * D * HxW;
- const int C = G * D;
- ReinitializeTensor(
- &ds_, {N, G}, at::dtype<float>().device(CUDA));
- ReinitializeTensor(
- &db_, {N, G}, at::dtype<float>().device(CUDA));
+ const int C = G * K;
+ ReinitializeTensor(&ds_, {N, G}, at::dtype<float>().device(CUDA));
+ ReinitializeTensor(&db_, {N, G}, at::dtype<float>().device(CUDA));
float* ds_data = ds_.mutable_data<float>();
float* db_data = db_.mutable_data<float>();
if (order_ == StorageOrder::NCHW) {
// Computes dL/ds and dL/db.
// dL/ds = Sum(dL/dY * gamma * X)
// dL/db = Sum(dL/dY * gamma)
- ComputeInternalGradientsCUDAKernel<float, StorageOrder::NCHW>
- <<<std::min(N * G, CAFFE_MAXIMUM_NUM_BLOCKS),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- N, G, D, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
+ ComputeInternalGradientsNCHWCUDAKernel<float>
+ <<<dim3(N, G), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+ G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
// Computes dL/dX.
- GroupNormBackwardCUDAKernel<float, StorageOrder::NCHW>
- <<<CAFFE_GET_BLOCKS(size),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- size,
+ const int W = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS);
+ GroupNormBackwardNCHWCUDAKernel<float>
+ <<<N * C * W, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
G,
- D,
+ K,
+ W,
HxW,
dY_data,
            X_data,
            mu_data,
            rsig_data,
            gamma_data,
            ds_data,
            db_data,
dX_data);
// Computes dL/dgamma and dL/dbeta.
- GammaBetaBackwardCUDAKernel<float, StorageOrder::NCHW>
- <<<std::min(C, CAFFE_MAXIMUM_NUM_BLOCKS),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- N,
- G,
- D,
- HxW,
- dY_data,
- X_data,
- mu_data,
- rsig_data,
- dgamma_data,
- dbeta_data);
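+    // Choose a 128-thread 2D block shape from HxW so that the HxW reduction
+    // dimension (threadIdx.y) gets as many threads as it can use.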
+ if (HxW >= 128) {
+ GammaBetaBackwardNCHWCUDAKernel<float, 1, 128>
+ <<<C, dim3(1, 128), 0, context_.cuda_stream()>>>(
+ N,
+ G,
+ K,
+ HxW,
+ dY_data,
+ X_data,
+ mu_data,
+ rsig_data,
+ dgamma_data,
+ dbeta_data);
+ } else if (HxW >= 64) {
+ GammaBetaBackwardNCHWCUDAKernel<float, 2, 64>
+ <<<C, dim3(2, 64), 0, context_.cuda_stream()>>>(
+ N,
+ G,
+ K,
+ HxW,
+ dY_data,
+ X_data,
+ mu_data,
+ rsig_data,
+ dgamma_data,
+ dbeta_data);
+ } else if (HxW >= 32) {
+ GammaBetaBackwardNCHWCUDAKernel<float, 4, 32>
+ <<<C, dim3(4, 32), 0, context_.cuda_stream()>>>(
+ N,
+ G,
+ K,
+ HxW,
+ dY_data,
+ X_data,
+ mu_data,
+ rsig_data,
+ dgamma_data,
+ dbeta_data);
+ } else {
+ GammaBetaBackwardNCHWCUDAKernel<float, 8, 16>
+ <<<C, dim3(8, 16), 0, context_.cuda_stream()>>>(
+ N,
+ G,
+ K,
+ HxW,
+ dY_data,
+ X_data,
+ mu_data,
+ rsig_data,
+ dgamma_data,
+ dbeta_data);
+ }
} else {
// Computes dL/ds and dL/db.
// dL/ds = Sum(dL/dY * gamma * X)
// dL/db = Sum(dL/dY * gamma)
- ComputeInternalGradientsCUDAKernel<float, StorageOrder::NHWC>
- <<<std::min(N * G, CAFFE_MAXIMUM_NUM_BLOCKS),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- N, G, D, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
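+    // Block shape keyed on K: the channel loop (threadIdx.y) gets up to 128
+    // threads, the remainder go to the HxW loop (threadIdx.x).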
+ if (K >= 128) {
+ ComputeInternalGradientsNHWCCUDAKernel<float, 1, 128>
+ <<<dim3(N, G), dim3(1, 128), 0, context_.cuda_stream()>>>(
+ G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
+ } else if (K >= 64) {
+ ComputeInternalGradientsNHWCCUDAKernel<float, 2, 64>
+ <<<dim3(N, G), dim3(2, 64), 0, context_.cuda_stream()>>>(
+ G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
+ } else if (K >= 32) {
+ ComputeInternalGradientsNHWCCUDAKernel<float, 4, 32>
+ <<<dim3(N, G), dim3(4, 32), 0, context_.cuda_stream()>>>(
+ G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
+ } else {
+ ComputeInternalGradientsNHWCCUDAKernel<float, 8, 16>
+ <<<dim3(N, G), dim3(8, 16), 0, context_.cuda_stream()>>>(
+ G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
+ }
// Computes dL/dX.
- GroupNormBackwardCUDAKernel<float, StorageOrder::NHWC>
- <<<CAFFE_GET_BLOCKS(size),
+ GroupNormBackwardNHWCCUDAKernel<float>
+ <<<dim3(N * HxW, G),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(
- size,
G,
- D,
+ K,
HxW,
dY_data,
            X_data,
            mu_data,
            rsig_data,
            gamma_data,
            ds_data,
            db_data,
dX_data);
// Computes dL/dgamma and dL/dbeta.
- GammaBetaBackwardCUDAKernel<float, StorageOrder::NHWC>
- <<<std::min(C, CAFFE_MAXIMUM_NUM_BLOCKS),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- N,
- G,
- D,
- HxW,
- dY_data,
- X_data,
- mu_data,
- rsig_data,
- dgamma_data,
- dbeta_data);
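+    // Same 128-thread block-shape heuristic as the NCHW path, keyed on HxW.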
+ if (HxW >= 128) {
+ GammaBetaBackwardNHWCCUDAKernel<float, 1, 128>
+ <<<C, dim3(1, 128), 0, context_.cuda_stream()>>>(
+ N,
+ G,
+ K,
+ HxW,
+ dY_data,
+ X_data,
+ mu_data,
+ rsig_data,
+ dgamma_data,
+ dbeta_data);
+ } else if (HxW >= 64) {
+ GammaBetaBackwardNHWCCUDAKernel<float, 2, 64>
+ <<<C, dim3(2, 64), 0, context_.cuda_stream()>>>(
+ N,
+ G,
+ K,
+ HxW,
+ dY_data,
+ X_data,
+ mu_data,
+ rsig_data,
+ dgamma_data,
+ dbeta_data);
+ } else if (HxW >= 32) {
+ GammaBetaBackwardNHWCCUDAKernel<float, 4, 32>
+ <<<C, dim3(4, 32), 0, context_.cuda_stream()>>>(
+ N,
+ G,
+ K,
+ HxW,
+ dY_data,
+ X_data,
+ mu_data,
+ rsig_data,
+ dgamma_data,
+ dbeta_data);
+ } else {
+ GammaBetaBackwardNHWCCUDAKernel<float, 8, 16>
+ <<<C, dim3(8, 16), 0, context_.cuda_stream()>>>(
+ N,
+ G,
+ K,
+ HxW,
+ dY_data,
+ X_data,
+ mu_data,
+ rsig_data,
+ dgamma_data,
+ dbeta_data);
+ }
}
return true;
}