From b145dcca048a28b1184be0d90f18aa29d23f0953 Mon Sep 17 00:00:00 2001 From: Xiaomeng Yang Date: Thu, 4 Apr 2019 11:46:37 -0700 Subject: [PATCH] Add support for group ConvTranspose (#18794) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18794 Add support for group ConvTranspose Reviewed By: houseroad Differential Revision: D14741327 fbshipit-source-id: 5d947ca044bf8495dd7f8f56122441ebbcc6c7e4 --- caffe2/operators/conv_transpose_op_cudnn.cc | 261 ++++-- caffe2/operators/conv_transpose_op_impl.h | 889 +++++++++++---------- caffe2/operators/conv_transpose_unpool_op_base.h | 38 +- caffe2/python/operator_test/conv_transpose_test.py | 63 ++ 4 files changed, 765 insertions(+), 486 deletions(-) diff --git a/caffe2/operators/conv_transpose_op_cudnn.cc b/caffe2/operators/conv_transpose_op_cudnn.cc index 8f8c9a2..459ccd7 100644 --- a/caffe2/operators/conv_transpose_op_cudnn.cc +++ b/caffe2/operators/conv_transpose_op_cudnn.cc @@ -1,7 +1,10 @@ +#include "caffe2/operators/conv_transpose_op.h" + +#include + #include "caffe2/core/context_gpu.h" #include "caffe2/core/cudnn_wrappers.h" #include "caffe2/operators/conv_op_cache_cudnn.h" -#include "caffe2/operators/conv_transpose_op.h" #include "caffe2/operators/op_utils_cudnn.h" namespace caffe2 { @@ -49,6 +52,7 @@ class CudnnConvTransposeOpBase : public ConvTransposeUnpoolBase { CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&filter_desc_)); if (InputSize() == 3) { CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&bias_desc_)); + CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&top_desc_for_bias_)); } CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&top_desc_)); CUDNN_ENFORCE(cudnnCreateConvolutionDescriptor(&conv_desc_)); @@ -59,27 +63,59 @@ class CudnnConvTransposeOpBase : public ConvTransposeUnpoolBase { CUDNN_ENFORCE(cudnnDestroyFilterDescriptor(filter_desc_)); if (InputSize() == 3) { CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(bias_desc_)); + CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(top_desc_for_bias_)); } CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(top_desc_)); CUDNN_ENFORCE(cudnnDestroyConvolutionDescriptor(conv_desc_)); } protected: - vector cudnn_input_dims_; - vector cudnn_filter_dims_; + void SetTensor4DDescriptorWithGroup( + const cudnnDataType_t data_type, + const int N, + const int C, + const int H, + const int W, + cudnnTensorDescriptor_t* desc) const { +#if CUDNN_VERSION_MIN(7, 0, 0) + const int CC = C; +#else + const int CC = C / group_; +#endif + switch (order_) { + case StorageOrder::NCHW: { + CUDNN_ENFORCE(cudnnSetTensor4dDescriptorEx( + *desc, data_type, N, CC, H, W, C * H * W, H * W, W, 1)); + break; + } + case StorageOrder::NHWC: { + CUDNN_ENFORCE(cudnnSetTensor4dDescriptorEx( + *desc, data_type, N, CC, H, W, H * W * C, 1, W * C, C)); + break; + } + default: { + LOG(FATAL) << "Unknown storage order: " << order_; + } + } + } + + std::vector cudnn_input_dims_; + std::vector cudnn_filter_dims_; CuDNNWrapper cudnn_wrapper_; cudnnTensorDescriptor_t bottom_desc_; cudnnFilterDescriptor_t filter_desc_; cudnnTensorDescriptor_t bias_desc_; cudnnTensorDescriptor_t top_desc_; + cudnnTensorDescriptor_t top_desc_for_bias_; cudnnConvolutionDescriptor_t conv_desc_; + const size_t cudnn_ws_nbytes_limit_; size_t cudnn_ws_nbytes_; bool exhaustive_search_; bool deterministic_; size_t cudnn_state_; - vector force_algo_; // stored as FWD, dFILTER, dDATA + std::vector force_algo_; // stored as FWD, dFILTER, dDATA bool enable_tensor_core_; }; @@ -141,10 +177,10 @@ bool CudnnConvTransposeOp::RunOnDevice() { int C = 0; switch (order_) { case StorageOrder::NHWC: - C = filter.dim32(3); + C = filter.dim32(3) * group_; break; case StorageOrder::NCHW: - C = filter.dim32(1); + C = filter.dim32(1) * group_; break; default: LOG(FATAL) << "Unknown storage order: " << order_; @@ -162,9 +198,8 @@ bool CudnnConvTransposeOp::RunOnDevice() { H_out = Y->dim32(1); W_out = Y->dim32(2); CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h()); - CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h()); CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_w()); - CAFFE_ENFORCE_EQ(filter.dim32(3), C); + CAFFE_ENFORCE_EQ(filter.dim32(3), C / group_); break; case StorageOrder::NCHW: N = X.dim32(0); @@ -173,13 +208,14 @@ bool CudnnConvTransposeOp::RunOnDevice() { W = X.dim32(3); H_out = Y->dim32(2); W_out = Y->dim32(3); - CAFFE_ENFORCE_EQ(filter.dim32(1), C); + CAFFE_ENFORCE_EQ(filter.dim32(1), C / group_); CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_h()); CAFFE_ENFORCE_EQ(filter.dim32(3), kernel_w()); break; default: LOG(FATAL) << "Unknown storage order: " << order_; } + CAFFE_ENFORCE_EQ(M % group_, 0); if (InputSize() == 3) { auto& bias = Input(BIAS); @@ -188,30 +224,29 @@ bool CudnnConvTransposeOp::RunOnDevice() { } // Set up the cudnn algorithms & workspace if necessary - bool input_changed = (X.sizes() != cudnn_input_dims_); - bool filter_changed = (filter.sizes() != cudnn_filter_dims_); + const bool input_changed = (X.sizes() != cudnn_input_dims_); + const bool filter_changed = (filter.sizes() != cudnn_filter_dims_); if (input_changed || filter_changed) { VLOG(1) << "Changing the cudnn descriptor configurations."; if (input_changed) { cudnn_input_dims_ = X.sizes().vec(); - CUDNN_ENFORCE(cudnnSetTensor4dDescriptor( - bottom_desc_, - GetCudnnTensorFormat(order_), - cudnnTypeWrapper::type, - N, - M, - H, - W)); + SetTensor4DDescriptorWithGroup( + cudnnTypeWrapper::type, N, M, H, W, &bottom_desc_); } if (filter_changed) { cudnn_filter_dims_ = filter.sizes().vec(); +#if CUDNN_VERSION_MIN(7, 0, 0) + const int MM = M; +#else + const int MM = M / group_; +#endif CUDNN_ENFORCE(cudnnSetFilter4dDescriptor( filter_desc_, cudnnTypeWrapper::type, GetCudnnTensorFormat(order_), - M, - C, + MM, + C / group_, kernel_h(), kernel_w())); if (InputSize() == 3) { @@ -226,14 +261,19 @@ bool CudnnConvTransposeOp::RunOnDevice() { } } // Set the output - CUDNN_ENFORCE(cudnnSetTensor4dDescriptor( - top_desc_, - GetCudnnTensorFormat(order_), - cudnnTypeWrapper::type, - N, - C, - H_out, - W_out)); + SetTensor4DDescriptorWithGroup( + cudnnTypeWrapper::type, N, C, H_out, W_out, &top_desc_); + if (InputSize() == 3) { + CUDNN_ENFORCE(cudnnSetTensor4dDescriptor( + top_desc_for_bias_, + GetCudnnTensorFormat(order_), + cudnnTypeWrapper::type, + N, + C, + H_out, + W_out)); + } + // Set the convolution descriptor CAFFE_ENFORCE_EQ( pad_t(), @@ -246,7 +286,7 @@ bool CudnnConvTransposeOp::RunOnDevice() { "The current padding scheme leads to unequal padding on the left " "and right, which is not supported by cudnn."); // Set the convolution descriptor -#if CUDNN_VERSION_MIN(6,0,0) +#if CUDNN_VERSION_MIN(6, 0, 0) CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor( conv_desc_, pad_t(), @@ -268,6 +308,7 @@ bool CudnnConvTransposeOp::RunOnDevice() { 1, CUDNN_CROSS_CORRELATION)); #endif + #if CUDNN_VERSION_MIN(7, 0, 0) // enable TensorCore math if desired enable_tensor_core_ &= TensorCoreAvailable(); @@ -275,7 +316,10 @@ bool CudnnConvTransposeOp::RunOnDevice() { CUDNN_ENFORCE( cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH)); } + // set cuDNN groups if appropriate + CUDNN_ENFORCE(cudnnSetConvolutionGroupCount(conv_desc_, group_)); #endif + if (force_algo_[ALGO_DGRAD] >= 0) { bwd_data_algo_ = (cudnnConvolutionBwdDataAlgo_t)force_algo_[ALGO_DGRAD]; } else if (deterministic_) { @@ -331,24 +375,56 @@ bool CudnnConvTransposeOp::RunOnDevice() { VLOG(1) << "CuDNN workspace size: " << bwd_data_ws_size; } + const T* X_data = X.template data(); + const T* filter_data = filter.template data(); + T* Y_data = Y->template mutable_data(); + // Now, actually run the computation. // Filter +#if CUDNN_VERSION_MIN(7, 0, 0) cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) { CUDNN_ENFORCE(cudnnConvolutionBackwardData( state->cudnn_handle(), cudnnTypeWrapper::kOne(), filter_desc_, - filter.template data(), + filter_data, bottom_desc_, - X.template data(), + X_data, conv_desc_, bwd_data_algo_, state->workspace().get(cudnn_ws_nbytes_), cudnn_ws_nbytes_, cudnnTypeWrapper::kZero(), top_desc_, - Y->template mutable_data())); + Y_data)); }); +#else + const int X_HxW = H * W; + const int Y_HxW = H_out * W_out; + const int group_offset_X = + order_ == StorageOrder::NCHW ? M / group_ * X_HxW : M / group_; + const int group_offset_Y = + order_ == StorageOrder::NCHW ? C / group_ * Y_HxW : C / group_; + const int group_offset_filter = filter.numel() / group_; + for (int i = 0; i < group_; ++i) { + cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) { + CUDNN_ENFORCE( + cudnnConvolutionBackwardData(state->cudnn_handle(), + cudnnTypeWrapper::kOne(), + filter_desc_, + filter_data + i * group_offset_filter, + bottom_desc_, + X_data + i * group_offset_X; + conv_desc_, + bwd_data_algo_, + state->workspace().get(cudnn_ws_nbytes_), + cudnn_ws_nbytes_, + cudnnTypeWrapper::kZero(), + top_desc_, + Y_data + i * group_offset_Y)); + }); + } +#endif // Bias if (InputSize() == 3) { CUDNN_ENFORCE(cudnnAddTensor( @@ -357,7 +433,7 @@ bool CudnnConvTransposeOp::RunOnDevice() { bias_desc_, Input(BIAS).template data(), cudnnTypeWrapper::kOne(), - top_desc_, + top_desc_for_bias_, Y->template mutable_data())); } // Done. @@ -368,19 +444,19 @@ bool CudnnConvTransposeOp::RunOnDevice() { // consolidating them. template bool CudnnConvTransposeGradientOp::RunOnDevice() { - auto& X = Input(INPUT); - auto& filter = Input(FILTER); - auto& dY = Input(OUTPUT_GRAD); + const auto& X = Input(INPUT); + const auto& filter = Input(FILTER); + const auto& dY = Input(OUTPUT_GRAD); CAFFE_ENFORCE_EQ(X.dim(), 4); CAFFE_ENFORCE_EQ(filter.dim(), 4); int C = 0; switch (order_) { case StorageOrder::NHWC: - C = filter.dim32(3); + C = filter.dim32(3) * group_; break; case StorageOrder::NCHW: - C = filter.dim32(1); + C = filter.dim32(1) * group_; break; default: LOG(FATAL) << "Unknown storage order: " << order_; @@ -398,7 +474,7 @@ bool CudnnConvTransposeGradientOp::RunOnDevice() { CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h()); CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h()); CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_w()); - CAFFE_ENFORCE_EQ(filter.dim32(3), C); + CAFFE_ENFORCE_EQ(filter.dim32(3), C / group_); break; case StorageOrder::NCHW: N = X.dim32(0); @@ -407,41 +483,42 @@ bool CudnnConvTransposeGradientOp::RunOnDevice() { W = X.dim32(3); H_out = dY.dim32(2); W_out = dY.dim32(3); - CAFFE_ENFORCE_EQ(filter.dim32(1), C); + CAFFE_ENFORCE_EQ(filter.dim32(1), C / group_); CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_h()); CAFFE_ENFORCE_EQ(filter.dim32(3), kernel_w()); break; default: LOG(FATAL) << "Unknown storage order: " << order_; } + CAFFE_ENFORCE_EQ(M % group_, 0); + // Since we only handle LegacyPadding::NOTSET, we don't need to // compute padding. auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype()); // Set up the cudnn algorithms & workspace if necessary - bool input_changed = (X.sizes() != cudnn_input_dims_); - bool filter_changed = (filter.sizes() != cudnn_filter_dims_); + const bool input_changed = (X.sizes() != cudnn_input_dims_); + const bool filter_changed = (filter.sizes() != cudnn_filter_dims_); if (input_changed || filter_changed) { VLOG(1) << "Changing the cudnn descriptor configurations."; if (input_changed) { cudnn_input_dims_ = X.sizes().vec(); - CUDNN_ENFORCE(cudnnSetTensor4dDescriptor( - bottom_desc_, - GetCudnnTensorFormat(order_), - cudnnTypeWrapper::type, - N, - M, - H, - W)); + SetTensor4DDescriptorWithGroup( + cudnnTypeWrapper::type, N, M, H, W, &bottom_desc_); } if (filter_changed) { cudnn_filter_dims_ = filter.sizes().vec(); +#if CUDNN_VERSION_MIN(7, 0, 0) + const int MM = M; +#else + const int MM = M / group_; +#endif CUDNN_ENFORCE(cudnnSetFilter4dDescriptor( filter_desc_, cudnnTypeWrapper::type, GetCudnnTensorFormat(order_), - M, - C, + MM, + C / group_, kernel_h(), kernel_w())); if (!no_bias_) { @@ -456,14 +533,19 @@ bool CudnnConvTransposeGradientOp::RunOnDevice() { } } // Set the output - CUDNN_ENFORCE(cudnnSetTensor4dDescriptor( - top_desc_, - GetCudnnTensorFormat(order_), - cudnnTypeWrapper::type, - N, - C, - H_out, - W_out)); + SetTensor4DDescriptorWithGroup( + cudnnTypeWrapper::type, N, C, H_out, W_out, &top_desc_); + if (!no_bias_) { + CUDNN_ENFORCE(cudnnSetTensor4dDescriptor( + top_desc_for_bias_, + GetCudnnTensorFormat(order_), + cudnnTypeWrapper::type, + N, + C, + H_out, + W_out)); + } + // Set the convolution descriptor CAFFE_ENFORCE_EQ( pad_t(), @@ -475,7 +557,7 @@ bool CudnnConvTransposeGradientOp::RunOnDevice() { pad_r(), "The current padding scheme leads to unequal padding on the left " "and right, which is not supported by cudnn."); -#if CUDNN_VERSION_MIN(6,0,0) +#if CUDNN_VERSION_MIN(6, 0, 0) CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor( conv_desc_, pad_t(), @@ -504,6 +586,8 @@ bool CudnnConvTransposeGradientOp::RunOnDevice() { CUDNN_ENFORCE( cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH)); } + // set cuDNN groups if appropriate + CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc_, group_)); #endif if (force_algo_[ALGO_WGRAD] >= 0) { bwd_filter_algo_ = @@ -622,13 +706,14 @@ bool CudnnConvTransposeGradientOp::RunOnDevice() { CUDNN_ENFORCE(cudnnConvolutionBackwardBias( cudnn_wrapper_.inline_cudnn_handle(), cudnnTypeWrapper::kOne(), - top_desc_, + top_desc_for_bias_, dY.template data(), cudnnTypeWrapper::kZero(), bias_desc_, dbias->template mutable_data())); } +#if CUDNN_VERSION_MIN(7, 0, 0) cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) { CUDNN_ENFORCE(cudnnConvolutionBackwardFilter( state->cudnn_handle(), @@ -647,7 +732,6 @@ bool CudnnConvTransposeGradientOp::RunOnDevice() { if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) { // Compute the gradient w.r.t. the input. - auto* dX = Output( no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD, X.sizes(), @@ -668,6 +752,55 @@ bool CudnnConvTransposeGradientOp::RunOnDevice() { dX->template mutable_data())); } }); +#else + const int X_HxW = H * W; + const int Y_HxW = H_out * W_out; + const int group_offset_X = + order_ == StorageOrder::NCHW ? M / group_ * X_HxW : M / group_; + const int group_offset_Y = + order_ == StorageOrder::NCHW ? C / group_ * Y_HxW : C / group_; + const int group_offset_filter = filter.numel() / group_; + for (int i = 0; i < group_; ++i) { + cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) { + CUDNN_ENFORCE(cudnnConvolutionBackwardFilter( + state->cudnn_handle(), + cudnnTypeWrapper::kOne(), + top_desc_, + dY.template data() + i * group_offset_Y, + bottom_desc_, + X.template data() + i * group_offset_X, + conv_desc_, + bwd_filter_algo_, + state->workspace().get(cudnn_ws_nbytes_), + cudnn_ws_nbytes_, + cudnnTypeWrapper::kZero(), + filter_desc_, + dfilter->template mutable_data() + i * group_offset_filter)); + if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) { + // Compute the gradient w.r.t. the input. + auto* dX = Output( + no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD, + X.sizes(), + at::dtype()); + cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) { + CUDNN_ENFORCE(cudnnConvolutionForward( + state->cudnn_handle(), + cudnnTypeWrapper::kOne(), + top_desc_, + dY.template data() + i * group_offset_Y, + filter_desc_, + filter.template data() + i * group_offset_filter, + conv_desc_, + algo_, + state->workspace().get(cudnn_ws_nbytes_), + cudnn_ws_nbytes_, + cudnnTypeWrapper::kZero(), + bottom_desc_, + dX->template mutable_data() + i * group_offset_X)); + }); + } + } +#endif return true; } diff --git a/caffe2/operators/conv_transpose_op_impl.h b/caffe2/operators/conv_transpose_op_impl.h index 41af81c..333f782 100644 --- a/caffe2/operators/conv_transpose_op_impl.h +++ b/caffe2/operators/conv_transpose_op_impl.h @@ -3,11 +3,15 @@ #ifndef CAFFE2_OPERATORS_CONV_TRANSPOSE_OP_IMPL_H_ #define CAFFE2_OPERATORS_CONV_TRANSPOSE_OP_IMPL_H_ +#include "caffe2/operators/conv_transpose_op.h" + +#include +#include + #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" #include "caffe2/operators/conv_op_shared.h" -#include "caffe2/operators/conv_transpose_op.h" #include "caffe2/operators/conv_transpose_unpool_op_base.h" #include "caffe2/utils/math.h" @@ -17,551 +21,618 @@ namespace caffe2 { template bool ConvTransposeOp::RunOnDeviceWithOrderNCHW() { - const Tensor& X = Input(INPUT); - auto& filter = Input(FILTER); - const int N = X.dim32(0), M = X.dim32(1), H = X.dim32(2), W = X.dim32(3); - CAFFE_ENFORCE(filter.dim() == 4, "filter must be 4D tensor"); - CAFFE_ENFORCE( - filter.dim32(0) == M, - "filter number must be equal to input channel number"); - const int C = filter.dim32(1); - CAFFE_ENFORCE( - filter.dim32(2) == this->kernel_h(), + const auto& X = Input(INPUT); + const auto& filter = Input(FILTER); + CAFFE_ENFORCE_EQ(X.dim(), 4, "Input must be 4D tensor"); + CAFFE_ENFORCE_EQ(filter.dim(), 4, "filter must be 4D tensor"); + const int N = X.dim32(0); + const int M = X.dim32(1); + const int H = X.dim32(2); + const int W = X.dim32(3); + const int G = group_; + CAFFE_ENFORCE_EQ(M, filter.dim32(0)); + CAFFE_ENFORCE_EQ( + M % G, 0, "The number of input channels is not divisible by group."); + const int C = filter.dim32(1) * G; + CAFFE_ENFORCE_EQ( + filter.dim32(2), + kernel_h(), "filter height must be equal to kernel height"); - CAFFE_ENFORCE( - filter.dim32(3) == this->kernel_w(), + CAFFE_ENFORCE_EQ( + filter.dim32(3), + this->kernel_w(), "filter width must be equal to kernel width"); - auto sizes = ConvTransposeUnpoolBase::GetOutputSize(X, C); - Tensor* Y = Output(0, sizes, at::dtype()); + const std::vector Y_dims = + ConvTransposeUnpoolBase::GetOutputSize(X, C); + auto* Y = Output(0, Y_dims, at::dtype()); - const int kernel_dim = C * this->kernel_h() * this->kernel_w(); - const int input_image_size = H * W; - const int output_image_size = Y->dim32(2) * Y->dim32(3); + if (N == 0) { + return true; + } + const int K_HxW = kernel_h() * kernel_w(); + const int kernel_dim = C / G * K_HxW; + const int X_HxW = H * W; + const int Y_HxW = Y->dim32(2) * Y->dim32(3); + + const T* X_data = X.template data(); + const T* filter_data = filter.template data(); + const T* bias_data = nullptr; if (InputSize() == 3) { auto& bias = Input(BIAS); - CAFFE_ENFORCE(bias.dim() == 1, "bias must be 1D tensor"); - CAFFE_ENFORCE( - bias.dim32(0) == C, + CAFFE_ENFORCE_EQ(bias.dim(), 1, "bias must be 1D tensor"); + CAFFE_ENFORCE_EQ( + bias.dim32(0), + C, "bias dimension must be equal to output channel number"); - ReinitializeTensor( - &bias_multiplier_, - {1, output_image_size}, - at::dtype().device(Context::GetDeviceType())); - T* bm_data = bias_multiplier_.template mutable_data(); - math::Set( - output_image_size, - static_cast(1), - bm_data, - &context_); + bias_data = bias.template data(); } + T* Y_data = Y->template mutable_data(); - const T* Xdata = X.template data(); - const T* filter_data = filter.template data(); - T* Ydata = Y->template mutable_data(); + const std::vector buffer_shape = { + C, kernel_h(), kernel_w(), H, W}; - auto f = [&](Tensor* col_buffer) { - ReinitializeTensor(col_buffer, vector{C, this->kernel_h(), this->kernel_w(), H, W}, at::dtype().device(Context::GetDeviceType())); + const auto func = [&](Tensor* col_buffer) { + ReinitializeTensor( + col_buffer, + buffer_shape, + at::dtype().device(Context::GetDeviceType())); T* col_buffer_data = col_buffer->template mutable_data(); - for (auto image_id = 0; image_id < N; ++image_id) { + for (int image_id = 0; image_id < N; ++image_id) { // Weight term - math::Gemm( - CblasTrans, - CblasNoTrans, - kernel_dim, - input_image_size, - M, - 1, - filter_data, - Xdata, - 0, - col_buffer_data, - &context_); + if (G == 1) { + math::Gemm( + CblasTrans, + CblasNoTrans, + kernel_dim, + X_HxW, + M, + 1.0f, + filter_data, + X_data + image_id * M * X_HxW, + 0.0f, + col_buffer_data, + &context_); + } else { + math::GemmStridedBatched( + CblasTrans, + CblasNoTrans, + G, + kernel_dim, + X_HxW, + M / G, + 1.0f, + filter_data, + M / G * kernel_dim, + X_data + image_id * M * X_HxW, + M / G * X_HxW, + 0.0f, + col_buffer_data, + col_buffer->numel() / G, + &context_); + } // Col2Im math::Col2Im( C, Y->dim32(2), Y->dim32(3), - this->kernel_h(), - this->kernel_w(), + kernel_h(), + kernel_w(), 1, 1, - this->pad_t(), - this->pad_l(), - this->pad_b(), - this->pad_r(), - this->stride_h(), - this->stride_w(), + pad_t(), + pad_l(), + pad_b(), + pad_r(), + stride_h(), + stride_w(), col_buffer_data, - Ydata, + Y_data + image_id * C * Y_HxW, &context_); - // Bias term - if (InputSize() == 3) { - const T* bias_data = Input(BIAS).template data(); - const T* bm_data = bias_multiplier_.template data(); -#if !defined(__ARM_NEON__) && !defined(__ARM_NEON) - math::Gemm( - CblasNoTrans, - CblasNoTrans, - C, - output_image_size, - 1, - 1, - bias_data, - bm_data, - 1, - Ydata, - &context_); -#else + if (bias_data != nullptr) { + // Bias term +#if defined(__ARM_NEON__) || defined(__ARM_NEON) math::BiasCHW( bias_data, - bm_data, + nullptr, C, - output_image_size, - Ydata, + Y_HxW, + Y_data + image_id * C * Y_HxW, &context_); #endif // !defined(__ARM_NEON__) && !defined(__ARM_NEON) } - - Xdata += M * H * W; - Ydata += Y->numel() / Y->dim32(0); + } + if (bias_data != nullptr) { +#if !defined(__ARM_NEON__) && !defined(__ARM_NEON) + // Bias term + const std::array Y_dims = {N, C, Y_HxW}; + const std::array b_dims = {1, C, 1}; + math::Add( + 3, + Y_dims.data(), + 3, + b_dims.data(), + Y_data, + bias_data, + Y_data, + &context_); +#endif } }; + if (FLAGS_caffe2_force_shared_col_buffer || shared_buffer_) { - runWithSharedBuffer(ws_, f); + runWithSharedBuffer(ws_, func); } else { - f(&col_buffer_); + func(&col_buffer_); } return true; } template bool ConvTransposeOp::RunOnDeviceWithOrderNHWC() { - const Tensor& X = Input(INPUT); + const auto& X = Input(INPUT); auto& filter = Input(FILTER); - const auto N = X.dim32(0), H = X.dim32(1), W = X.dim32(2), M = X.dim32(3); - CAFFE_ENFORCE(filter.dim() == 4, "filter must be 4D tensor"); - CAFFE_ENFORCE( - filter.dim32(0) == M, + CAFFE_ENFORCE_EQ(filter.dim(), 4, "filter must be 4D tensor"); + const int N = X.dim32(0); + const int H = X.dim32(1); + const int W = X.dim32(2); + const int M = X.dim32(3); + const int G = group_; + CAFFE_ENFORCE_EQ( + filter.dim32(0), + M, "filter number must be equal to input channel number"); - CAFFE_ENFORCE( - filter.dim32(1) == this->kernel_h(), + CAFFE_ENFORCE_EQ( + M % G, 0, "The number of input channels is not divisible by group."); + const int C = filter.dim32(3) * G; + CAFFE_ENFORCE_EQ( + filter.dim32(1), + kernel_h(), "filter height must be equal to kernel height"); - CAFFE_ENFORCE( - filter.dim32(2) == this->kernel_w(), + CAFFE_ENFORCE_EQ( + filter.dim32(2), + kernel_w(), "filter width must be equal to kernel width"); - const int C = filter.dim32(3); - auto sizes = ConvTransposeUnpoolBase::GetOutputSize(X, C); - Tensor* Y = Output(0, sizes, at::dtype()); - const auto kernel_dim = C * this->kernel_h() * this->kernel_w(); - const auto input_image_size = H * W; - const auto output_image_size = Y->dim32(1) * Y->dim32(2); + const std::vector Y_dims = + ConvTransposeUnpoolBase::GetOutputSize(X, C); + auto* Y = Output(0, Y_dims, at::dtype()); + if (N == 0) { + return true; + } + + const int K_HxW = kernel_h() * kernel_w(); + const int kernel_dim = C / G * K_HxW; + const int X_HxW = H * W; + const int Y_HxW = Y->dim32(1) * Y->dim32(2); + const T* X_data = X.template data(); + const T* filter_data = filter.template data(); + const T* bias_data = nullptr; if (InputSize() == 3) { auto& bias = Input(BIAS); - CAFFE_ENFORCE(bias.dim() == 1, "bias must be 1D tensor"); - CAFFE_ENFORCE( - bias.dim32(0) == C, + CAFFE_ENFORCE_EQ(bias.dim(), 1, "bias must be 1D tensor"); + CAFFE_ENFORCE_EQ( + bias.dim32(0), + C, "bias dimension must be equal to output channel number"); - // TODO(jerryzh): is it OK to remove the check of whether numel is output_image_size - ReinitializeTensor( - &bias_multiplier_, - {1, output_image_size}, - at::dtype().device(Context::GetDeviceType())); - T* bm_data = bias_multiplier_.template mutable_data(); - math::Set( - output_image_size, - static_cast(1), - bm_data, - &context_); + bias_data = bias.template data(); } - const T* Xdata = X.template data(); - const T* filter_data = filter.template data(); - T* Ydata = Y->template mutable_data(); + T* Y_data = Y->template mutable_data(); - auto f = [&](Tensor* /*col_buffer*/) { + const std::vector buffer_shape = { + G, H, W, kernel_h(), kernel_w(), C / G}; + const auto func = [&](Tensor* /*col_buffer*/) { ReinitializeTensor( &col_buffer_, - vector{H, W, this->kernel_h(), this->kernel_w(), C}, + buffer_shape, at::dtype().device(Context::GetDeviceType())); T* col_buffer_data = col_buffer_.template mutable_data(); - for (auto image_id = 0; image_id < N; ++image_id) { + for (int image_id = 0; image_id < N; ++image_id) { // Weight term - math::Gemm( - CblasNoTrans, - CblasNoTrans, - input_image_size, - kernel_dim, - M, - 1, - Xdata, - filter_data, - 0, - col_buffer_data, - &context_); + if (G == 1) { + math::Gemm( + CblasNoTrans, + CblasNoTrans, + X_HxW, + kernel_dim, + M, + 1.0f, + X_data + image_id * M * X_HxW, + filter_data, + 0.0f, + col_buffer_data, + &context_); + } else { + for (int group_id = 0; group_id < G; ++group_id) { + math::GemmEx( + CblasNoTrans, + CblasNoTrans, + X_HxW, + kernel_dim, + M / G, + 1.0f, + X_data + image_id * M * X_HxW + group_id * M / G, + M, + filter_data + group_id * M / G * kernel_dim, + kernel_dim, + 0.0f, + col_buffer_data + group_id * kernel_dim, + G * kernel_dim, + &context_); + } + } // Col2Im math::Col2Im( C, Y->dim32(1), Y->dim32(2), - this->kernel_h(), - this->kernel_w(), + kernel_h(), + kernel_w(), 1, 1, - this->pad_t(), - this->pad_l(), - this->pad_b(), - this->pad_r(), - this->stride_h(), - this->stride_w(), + pad_t(), + pad_l(), + pad_b(), + pad_r(), + stride_h(), + stride_w(), col_buffer_data, - Ydata, - &context_); + Y_data + image_id * C * Y_HxW, + &context_, + G); + } + if (bias_data != nullptr) { // Bias term - if (InputSize() == 3) { - const T* bm_data = bias_multiplier_.template data(); - const T* bias_data = Input(BIAS).template data(); - math::Gemm( - CblasNoTrans, - CblasNoTrans, - output_image_size, - C, - 1, - 1, - bm_data, - bias_data, - 1, - Ydata, - &context_); - } - Xdata += M * H * W; - Ydata += Y->numel() / Y->dim32(0); + const std::array Y_dims = {N * Y_HxW, C}; + const std::array b_dims = {1, C}; + math::Add( + 2, + Y_dims.data(), + 2, + b_dims.data(), + Y_data, + bias_data, + Y_data, + &context_); } }; + if (FLAGS_caffe2_force_shared_col_buffer || shared_buffer_) { - runWithSharedBuffer(ws_, f); + runWithSharedBuffer(ws_, func); } else { - f(&col_buffer_); + func(&col_buffer_); } return true; } template bool ConvTransposeGradientOp::RunOnDeviceWithOrderNCHW() { - auto& X = Input(INPUT); - auto& filter = Input(FILTER); - auto& dY = Input(OUTPUT_GRAD); - - const int N = X.dim32(0), M = X.dim32(1), H = X.dim32(2), W = X.dim32(3); - // We only handle LegacyPadding::NOTSET case and ignore cases of - // LegacyPadding::VALID and LegacyPadding::SAME - // Thus, we don't need to manually compute padding values - // We simply use the values from the user - CAFFE_ENFORCE(filter.dim() == 4); - const int C = filter.dim32(1); - CAFFE_ENFORCE( - filter.dim32(2) == this->kernel_h(), + const auto& X = Input(INPUT); + const auto& filter = Input(FILTER); + const auto& dY = Input(OUTPUT_GRAD); + CAFFE_ENFORCE_EQ(filter.dim(), 4); + const int N = X.dim32(0); + const int M = X.dim32(1); + const int H = X.dim32(2); + const int W = X.dim32(3); + const int G = group_; + CAFFE_ENFORCE_EQ(M, filter.dim32(0)); + CAFFE_ENFORCE_EQ( + M % G, 0, "The number of input channels is not divisible by group."); + const int C = filter.dim32(1) * G; + CAFFE_ENFORCE_EQ(C, dY.dim32(1)); + CAFFE_ENFORCE_EQ( + filter.dim32(2), + kernel_h(), "filter height must be equal to kernel height"); - CAFFE_ENFORCE( - filter.dim32(3) == this->kernel_w(), + CAFFE_ENFORCE_EQ( + filter.dim32(3), + this->kernel_w(), "filter width must be equal to kernel width"); + + const int K_HxW = kernel_h() * kernel_w(); + const int kernel_dim = C / G * K_HxW; + const int X_HxW = H * W; + const int Y_HxW = dY.dim32(2) * dY.dim32(3); auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype()); - const int kernel_dim = C * this->kernel_h() * this->kernel_w(); - const int output_image_size = dY.dim32(2) * dY.dim32(3); - // The col buffer is stored in CHW order as well - ReinitializeTensor( - &col_buffer_, - vector{C, this->kernel_h(), this->kernel_w(), H, W}, - at::dtype().device(Context::GetDeviceType())); - if (!no_bias_) { - auto* dbias = Output(BIAS_OR_INPUT_GRAD); - dbias->Resize(C); - // TODO(jerryzh): is it OK to remove the check of whether numel is output_image_size - ReinitializeTensor( - &bias_multiplier_, - {1, output_image_size}, - at::dtype().device(Context::GetDeviceType())); - T* bm_data = bias_multiplier_.template mutable_data(); - math::Set( - output_image_size, - static_cast(1), - bm_data, - &context_); - } - T* col_buffer_data = col_buffer_.template mutable_data(); - const T* Xdata = X.template data(); + const T* X_data = X.template data(); const T* filter_data = filter.template data(); - const T* dYdata = dY.template data(); + const T* dY_data = dY.template data(); T* dfilter_data = dfilter->template mutable_data(); - // Pre-setting the gradients to zero - math::Set(dfilter->numel(), 0, dfilter_data, &context_); + T* dbias_data = nullptr; + T* dX_data = nullptr; if (!no_bias_) { - auto* dbias = Output(BIAS_OR_INPUT_GRAD); - T* dbias_data = dbias->template mutable_data(); - math::Set(dbias->numel(), 0, dbias_data, &context_); + auto* dbias = Output(BIAS_OR_INPUT_GRAD, {C}, at::dtype()); + dbias_data = dbias->template mutable_data(); } - for (auto image_id = 0; image_id < N; ++image_id) { + const bool compute_dX = + (OutputSize() == 3) || (no_bias_ && (OutputSize() == 2)); + if (compute_dX) { + auto* dX = Output( + no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD, X.sizes(), at::dtype()); + dX_data = dX->template mutable_data(); + } + math::Set(filter.numel(), T(0), dfilter_data, &context_); + + if (N == 0) { + math::Set(C, T(0), dbias_data, &context_); + return true; + } + + ReinitializeTensor( + &col_buffer_, + std::vector{C, kernel_h(), kernel_w(), H, W}, + at::dtype().device(Context::GetDeviceType())); + T* col_buffer_data = col_buffer_.template mutable_data(); + + for (int image_id = 0; image_id < N; ++image_id) { // gradient w.r.t. filters. Im2Col followed by Gemm // Im2Col. math::Im2Col( C, dY.dim32(2), dY.dim32(3), - this->kernel_h(), - this->kernel_w(), + kernel_h(), + kernel_w(), 1, 1, - this->pad_t(), - this->pad_l(), - this->pad_b(), - this->pad_r(), - this->stride_h(), - this->stride_w(), - dYdata, + pad_t(), + pad_l(), + pad_b(), + pad_r(), + stride_h(), + stride_w(), + dY_data + image_id * C * Y_HxW, col_buffer_data, &context_); // Gemm - math::Gemm( - CblasNoTrans, - CblasTrans, - M, - kernel_dim, - H * W, - 1, - Xdata, - col_buffer_data, - 1, - dfilter_data, - &context_); - // gradient w.r.t. bias - if (!no_bias_) { - const T* bm_data = bias_multiplier_.template data(); - T* input_grad_data = Output(BIAS_OR_INPUT_GRAD)->template mutable_data(); + if (G == 1) { math::Gemm( CblasNoTrans, - CblasNoTrans, - C, - 1, - output_image_size, - 1, - dYdata, - bm_data, - 1, - input_grad_data, - &context_); - } - dYdata += dY.numel() / dY.dim32(0); - Xdata += X.numel() / X.dim32(0); - } - if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) { - // Compute gradients w.r.t. the input - // Since we have changed dYdata in the above loop, we will need to reset. - dYdata = dY.template data(); - - auto* dX = Output( - no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD, X.sizes(), at::dtype()); - T* dXdata = dX->template mutable_data(); - for (auto image_id = 0; image_id < N; ++image_id) { - // Im2Col. - // TODO(zyan3): Probably duplicate work as in gradient computation - // w.r.t filters - math::Im2Col( - C, - dY.dim32(2), - dY.dim32(3), - this->kernel_h(), - this->kernel_w(), - 1, - 1, - this->pad_t(), - this->pad_l(), - this->pad_b(), - this->pad_r(), - this->stride_h(), - this->stride_w(), - dYdata, + CblasTrans, + M, + kernel_dim, + X_HxW, + 1.0f, + X_data + image_id * M * X_HxW, col_buffer_data, + 1.0f, + dfilter_data, &context_); - // Gemm - math::Gemm( + } else { + math::GemmStridedBatched( CblasNoTrans, - CblasNoTrans, - M, - H * W, + CblasTrans, + G, + M / G, kernel_dim, - 1, - filter_data, + X_HxW, + 1.0f, + X_data + image_id * M * X_HxW, + M / G * X_HxW, col_buffer_data, - 0, - dXdata, + col_buffer_.numel() / G, + 1.0f, + dfilter_data, + M / G * kernel_dim, &context_); - dYdata += dY.numel() / dY.dim32(0); - dXdata += X.numel() / X.dim32(0); + } + + if (dX_data != nullptr) { + // Compute gradients w.r.t. the input + if (G == 1) { + math::Gemm( + CblasNoTrans, + CblasNoTrans, + M, + X_HxW, + kernel_dim, + 1.0f, + filter_data, + col_buffer_data, + 0.0f, + dX_data + image_id * M * X_HxW, + &context_); + } else { + math::GemmStridedBatched( + CblasNoTrans, + CblasNoTrans, + G, + M / G, + X_HxW, + kernel_dim, + 1.0f, + filter_data, + M / G * kernel_dim, + col_buffer_data, + col_buffer_.numel() / G, + 0.0f, + dX_data + image_id * M * X_HxW, + M / G * X_HxW, + &context_); + } } } + + if (dbias_data != nullptr) { + // gradient w.r.t. bias + const std::array Y_dims = {N, C, Y_HxW}; + const std::array b_dims = {1, C, 1}; + math::ReduceSum( + 3, Y_dims.data(), b_dims.data(), T(1), dY_data, dbias_data, &context_); + } + return true; } template bool ConvTransposeGradientOp::RunOnDeviceWithOrderNHWC() { - auto& X = Input(INPUT); - auto& filter = Input(FILTER); - auto& dY = Input(OUTPUT_GRAD); - - const int N = X.dim32(0), H = X.dim32(1), W = X.dim32(2), M = X.dim32(3); - // We only handle LegacyPadding::NOTSET case and ignore cases of - // LegacyPadding::VALID and LegacyPadding::SAME - // Thus, we don't need to manually compute padding values - // We simply use the values from the user - CAFFE_ENFORCE(filter.dim() == 4, "filter must be 4D tensor"); - CAFFE_ENFORCE( - filter.dim32(1) == this->kernel_h(), + const auto& X = Input(INPUT); + const auto& filter = Input(FILTER); + const auto& dY = Input(OUTPUT_GRAD); + CAFFE_ENFORCE_EQ(filter.dim(), 4); + const int N = X.dim32(0); + const int H = X.dim32(1); + const int W = X.dim32(2); + const int M = X.dim32(3); + const int G = group_; + CAFFE_ENFORCE_EQ(M, filter.dim32(0)); + CAFFE_ENFORCE_EQ( + M % G, 0, "The number of input channels is not divisible by group."); + const int C = filter.dim32(3) * G; + CAFFE_ENFORCE_EQ(C, dY.dim32(3)); + CAFFE_ENFORCE_EQ( + filter.dim32(1), + kernel_h(), "filter height must be equal to kernel height"); - CAFFE_ENFORCE( - filter.dim32(2) == this->kernel_w(), + CAFFE_ENFORCE_EQ( + filter.dim32(2), + this->kernel_w(), "filter width must be equal to kernel width"); - const int C = filter.dim32(3); + CAFFE_ENFORCE_EQ(dY.dim32(3), C); + + const int K_HxW = kernel_h() * kernel_w(); + const int kernel_dim = C / G * K_HxW; + const int X_HxW = H * W; + const int Y_HxW = dY.dim32(1) * dY.dim32(2); auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype()); - const int kernel_dim = C * this->kernel_h() * this->kernel_w(); - const int output_image_size = dY.dim32(1) * dY.dim32(2); - // The col buffer is stored in HWC order as well - ReinitializeTensor( - &col_buffer_, - vector{H, W, this->kernel_h(), this->kernel_w(), C}, - at::dtype().device(Context::GetDeviceType())); - if (!no_bias_) { - auto* dbias = Output(BIAS_OR_INPUT_GRAD); - dbias->Resize(C); - // TODO(jerryzh): is it OK to remove the check of whether numel is output_image_size - ReinitializeTensor( - &bias_multiplier_, - {1, output_image_size}, - at::dtype().device(Context::GetDeviceType())); - T* bm_data = bias_multiplier_.template mutable_data(); - math::Set( - output_image_size, - static_cast(1), - bm_data, - &context_); - } - T* col_buffer_data = col_buffer_.template mutable_data(); - const T* Xdata = X.template data(); + const T* X_data = X.template data(); const T* filter_data = filter.template data(); - const T* dYdata = dY.template data(); + const T* dY_data = dY.template data(); T* dfilter_data = dfilter->template mutable_data(); - // Pre-setting the gradients to zero - math::Set(dfilter->numel(), 0, dfilter_data, &context_); + T* dbias_data = nullptr; + T* dX_data = nullptr; if (!no_bias_) { - auto* dbias = Output(BIAS_OR_INPUT_GRAD); - T* dbias_data = dbias->template mutable_data(); - math::Set(dbias->numel(), 0, dbias_data, &context_); + auto* dbias = Output(BIAS_OR_INPUT_GRAD, {C}, at::dtype()); + dbias_data = dbias->template mutable_data(); + } + const bool compute_dX = + (OutputSize() == 3) || (no_bias_ && (OutputSize() == 2)); + if (compute_dX) { + auto* dX = Output( + no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD, X.sizes(), at::dtype()); + dX_data = dX->template mutable_data(); } - for (auto image_id = 0; image_id < N; ++image_id) { + math::Set(filter.numel(), T(0), dfilter_data, &context_); + + if (N == 0) { + math::Set(C, T(0), dbias_data, &context_); + return true; + } + + ReinitializeTensor( + &col_buffer_, + std::vector{C, kernel_h(), kernel_w(), H, W}, + at::dtype().device(Context::GetDeviceType())); + T* col_buffer_data = col_buffer_.template mutable_data(); + + for (int image_id = 0; image_id < N; ++image_id) { // gradient w.r.t. filters. Im2Col followed by Gemm // Im2Col. math::Im2Col( C, dY.dim32(1), dY.dim32(2), - this->kernel_h(), - this->kernel_w(), + kernel_h(), + kernel_w(), 1, 1, - this->pad_t(), - this->pad_l(), - this->pad_b(), - this->pad_r(), - this->stride_h(), - this->stride_w(), - dYdata, + pad_t(), + pad_l(), + pad_b(), + pad_r(), + stride_h(), + stride_w(), + dY_data + image_id * C * Y_HxW, col_buffer_data, - &context_); + &context_, + G); // Gemm - math::Gemm( - CblasTrans, - CblasNoTrans, - M, - kernel_dim, - H * W, - 1, - Xdata, - col_buffer_data, - 1, - dfilter_data, - &context_); - // gradients w.r.t. bias - if (!no_bias_) { - const T* bm_data = bias_multiplier_.template data(); - T* input_grad_data = Output(BIAS_OR_INPUT_GRAD)->template mutable_data(); + if (G == 1) { math::Gemm( CblasTrans, CblasNoTrans, - C, - 1, - output_image_size, - 1, - dYdata, - bm_data, - 1, - input_grad_data, - &context_); - } - dYdata += dY.numel() / dY.dim32(0); - Xdata += X.numel() / X.dim32(0); - } - if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) { - // Compute gradients w.r.t. the input - // Since we have changed dYdata in the above loop, we will need to reset. - dYdata = dY.template data(); - - auto* dX = Output( - no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD, X.sizes(), at::dtype()); - T* dXdata = dX->template mutable_data(); - for (auto image_id = 0; image_id < N; ++image_id) { - // Im2Col. - // TODO(zyan3): Probably duplicate work as in gradient computation - // w.r.t filters - math::Im2Col( - C, - dY.dim32(1), - dY.dim32(2), - this->kernel_h(), - this->kernel_w(), - 1, - 1, - this->pad_t(), - this->pad_l(), - this->pad_b(), - this->pad_r(), - this->stride_h(), - this->stride_w(), - dYdata, - col_buffer_data, - &context_); - // Gemm - math::Gemm( - CblasNoTrans, - CblasTrans, - H * W, M, kernel_dim, - 1, + X_HxW, + 1.0f, + X_data + image_id * M * X_HxW, col_buffer_data, - filter_data, - 0, - dXdata, + 1.0f, + dfilter_data, &context_); - dYdata += dY.numel() / dY.dim32(0); - dXdata += X.numel() / X.dim32(0); + } else { + for (int group_id = 0; group_id < G; ++group_id) { + math::GemmEx( + CblasTrans, + CblasNoTrans, + M / G, + kernel_dim, + X_HxW, + 1.0f, + X_data + image_id * M * X_HxW + group_id * M / G, + M, + col_buffer_data + group_id * kernel_dim, + G * kernel_dim, + 1.0f, + dfilter_data + group_id * M / G * kernel_dim, + kernel_dim, + &context_); + } + } + + if (dX_data != nullptr) { + // Compute gradients w.r.t. the input + if (G == 1) { + math::Gemm( + CblasNoTrans, + CblasTrans, + X_HxW, + M, + kernel_dim, + 1.0f, + col_buffer_data, + filter_data, + 0.0f, + dX_data + image_id * M * X_HxW, + &context_); + } else { + for (int group_id = 0; group_id < G; ++group_id) { + math::GemmEx( + CblasNoTrans, + CblasTrans, + X_HxW, + M / G, + kernel_dim, + 1.0f, + col_buffer_data + group_id * kernel_dim, + G * kernel_dim, + filter_data + group_id * M / G * kernel_dim, + kernel_dim, + 0.0f, + dX_data + image_id * M * X_HxW + group_id * M / G, + M, + &context_); + } + } } } + + if (dbias_data != nullptr) { + const std::array Y_dims = {N * Y_HxW, C}; + const std::array b_dims = {1, C}; + math::ReduceSum( + 2, Y_dims.data(), b_dims.data(), T(1), dY_data, dbias_data, &context_); + } + return true; } } // namespace caffe2 + #endif // CAFFE2_OPERATORS_CONV_TRANSPOSE_OP_IMPL_H_ diff --git a/caffe2/operators/conv_transpose_unpool_op_base.h b/caffe2/operators/conv_transpose_unpool_op_base.h index 7ebfda7..c98b3ba 100644 --- a/caffe2/operators/conv_transpose_unpool_op_base.h +++ b/caffe2/operators/conv_transpose_unpool_op_base.h @@ -17,7 +17,9 @@ template class ConvTransposeUnpoolBase : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; - explicit ConvTransposeUnpoolBase(const OperatorDef& operator_def, Workspace* ws) + explicit ConvTransposeUnpoolBase( + const OperatorDef& operator_def, + Workspace* ws) : Operator(operator_def, ws), legacy_pad_( static_cast(this->template GetSingleArgument( @@ -27,6 +29,7 @@ class ConvTransposeUnpoolBase : public Operator { stride_(this->template GetRepeatedArgument("strides")), pads_(this->template GetRepeatedArgument("pads")), adj_(this->template GetRepeatedArgument("adjs")), + group_(this->template GetSingleArgument("group", 1)), order_(StringToStorageOrder( this->template GetSingleArgument("order", "NCHW"))), shared_buffer_( @@ -206,19 +209,7 @@ class ConvTransposeUnpoolBase : public Operator { virtual ~ConvTransposeUnpoolBase() {} - private: - LegacyPadding legacy_pad_; - int pad_; - protected: - vector kernel_; - vector stride_; - vector pads_; - vector adj_; - StorageOrder order_; - bool shared_buffer_; - Workspace* ws_; - // Accessors for 2D conv params. inline int pad_t() const { @@ -289,14 +280,35 @@ class ConvTransposeUnpoolBase : public Operator { break; } } + + LegacyPadding legacy_pad_; + int pad_; + + std::vector kernel_; + std::vector stride_; + std::vector pads_; + std::vector adj_; + int group_; + StorageOrder order_; + bool shared_buffer_; + Workspace* ws_; }; #define USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS(Context) \ USE_OPERATOR_FUNCTIONS(Context); \ using ConvTransposeUnpoolBase::kernel_; \ + using ConvTransposeUnpoolBase::kernel_h; \ + using ConvTransposeUnpoolBase::kernel_w; \ using ConvTransposeUnpoolBase::stride_; \ + using ConvTransposeUnpoolBase::stride_h; \ + using ConvTransposeUnpoolBase::stride_w; \ using ConvTransposeUnpoolBase::pads_; \ + using ConvTransposeUnpoolBase::pad_t; \ + using ConvTransposeUnpoolBase::pad_l; \ + using ConvTransposeUnpoolBase::pad_b; \ + using ConvTransposeUnpoolBase::pad_r; \ using ConvTransposeUnpoolBase::adj_; \ + using ConvTransposeUnpoolBase::group_; \ using ConvTransposeUnpoolBase::order_; \ using ConvTransposeUnpoolBase::shared_buffer_; \ using ConvTransposeUnpoolBase::ws_ diff --git a/caffe2/python/operator_test/conv_transpose_test.py b/caffe2/python/operator_test/conv_transpose_test.py index 2a32f9a..272ac3a 100644 --- a/caffe2/python/operator_test/conv_transpose_test.py +++ b/caffe2/python/operator_test/conv_transpose_test.py @@ -6,6 +6,7 @@ import numpy as np from hypothesis import assume, given, settings import hypothesis.strategies as st +from caffe2.proto import caffe2_pb2 from caffe2.python import core, utils import caffe2.python.hypothesis_test_util as hu import caffe2.python.hip_test_util as hiputl @@ -360,6 +361,68 @@ class TestConvolutionTranspose(hu.HypothesisTestCase): for i in outputs_to_check: self.assertGradientChecks(gc, op, inputs, i, [0]) + @given(stride=st.integers(1, 3), + pad=st.integers(0, 3), + kernel=st.integers(1, 3), + adj=st.integers(0, 2), + size=st.integers(7, 10), + input_channels=st.integers(1, 8), + output_channels=st.integers(1, 8), + batch_size=st.integers(1, 4), + group=st.integers(1, 4), + order=st.sampled_from(["NCHW", "NHWC"]), + engine=st.sampled_from(["", "CUDNN", "BLOCK"]), + shared_buffer=st.booleans(), + use_bias=st.booleans(), + **hu.gcs) + def test_convolution_transpose_with_group( + self, stride, pad, kernel, adj, size, input_channels, + output_channels, batch_size, group, order, engine, shared_buffer, + use_bias, gc, dc): + assume(adj < stride) + # TODO: Group conv_transpose in NHWC not implemented for GPU yet. + assume(group == 1 or order == "NCHW" or + gc.device_type == caffe2_pb2.CPU) + if group != 1 and order == "NHWC": + dc = [d for d in dc if d.device_type == caffe2_pb2.CPU] + + if hiputl.run_in_hip(gc, dc) and order == "NHWC": + engine = "" + + op = core.CreateOperator( + "ConvTranspose", + ["X", "w", "b"] if use_bias else ["X", "w"], + ["Y"], + stride=stride, + kernel=kernel, + pad=pad, + adj=adj, + group=group, + order=order, + engine=engine, + shared_buffer=int(shared_buffer), + device_option=gc, + ) + + input_channels *= group + output_channels *= group + + X = np.random.rand( + batch_size, size, size, input_channels).astype(np.float32) - 0.5 + w = np.random.rand( + input_channels, kernel, kernel, int(output_channels / group)) \ + .astype(np.float32) - 0.5 + b = np.random.rand(output_channels).astype(np.float32) - 0.5 + if order == "NCHW": + X = utils.NHWC2NCHW(X) + w = utils.NHWC2NCHW(w) + + inputs = [X, w, b] if use_bias else [X, w] + self.assertDeviceChecks(dc, op, inputs, [0]) + for i in range(len(inputs)): + self.assertGradientChecks(gc, op, inputs, i, [0]) + + if __name__ == "__main__": import unittest unittest.main() -- 2.7.4