Add support for group ConvTranspose (#18794)
author    Xiaomeng Yang <yangxm@fb.com>
          Thu, 4 Apr 2019 18:46:37 +0000 (11:46 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
          Thu, 4 Apr 2019 18:52:06 +0000 (11:52 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18794

Add support for group ConvTranspose

Reviewed By: houseroad

Differential Revision: D14741327

fbshipit-source-id: 5d947ca044bf8495dd7f8f56122441ebbcc6c7e4
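
For context, a minimal usage sketch (not part of this commit) showing how the new `group` argument is driven from Python; the blob names and shapes are illustrative and mirror the unit test added to conv_transpose_test.py below. In NCHW order the ConvTranspose filter is laid out as (M, C / group, kernel_h, kernel_w), where M is the number of input channels and C the number of output channels.

import numpy as np
from caffe2.python import core, workspace

N, M, H, W = 2, 8, 7, 7       # input in NCHW order: (N, M, H, W)
G = 4                         # number of groups; M must be divisible by G
C = 8                         # output channels; the filter stores C / G of them per group

X = np.random.randn(N, M, H, W).astype(np.float32)
w = np.random.randn(M, C // G, 3, 3).astype(np.float32)   # (M, C/G, kH, kW)
b = np.random.randn(C).astype(np.float32)

op = core.CreateOperator(
    "ConvTranspose", ["X", "w", "b"], ["Y"],
    kernel=3, stride=2, pad=1, adj=0, group=G, order="NCHW",
)
workspace.FeedBlob("X", X)
workspace.FeedBlob("w", w)
workspace.FeedBlob("b", b)
workspace.RunOperatorOnce(op)
print(workspace.FetchBlob("Y").shape)   # (N, C, H_out, W_out)

On cuDNN 7+ the grouped case maps onto cudnnSetConvolutionGroupCount; on the CPU path it is handled with per-group and strided-batched GEMMs, as the diffs below show.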

caffe2/operators/conv_transpose_op_cudnn.cc
caffe2/operators/conv_transpose_op_impl.h
caffe2/operators/conv_transpose_unpool_op_base.h
caffe2/python/operator_test/conv_transpose_test.py

caffe2/operators/conv_transpose_op_cudnn.cc
index 8f8c9a2..459ccd7 100644
@@ -1,7 +1,10 @@
+#include "caffe2/operators/conv_transpose_op.h"
+
+#include <vector>
+
 #include "caffe2/core/context_gpu.h"
 #include "caffe2/core/cudnn_wrappers.h"
 #include "caffe2/operators/conv_op_cache_cudnn.h"
-#include "caffe2/operators/conv_transpose_op.h"
 #include "caffe2/operators/op_utils_cudnn.h"
 
 namespace caffe2 {
@@ -49,6 +52,7 @@ class CudnnConvTransposeOpBase : public ConvTransposeUnpoolBase<CUDAContext> {
     CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&filter_desc_));
     if (InputSize() == 3) {
       CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&bias_desc_));
+      CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&top_desc_for_bias_));
     }
     CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&top_desc_));
     CUDNN_ENFORCE(cudnnCreateConvolutionDescriptor(&conv_desc_));
@@ -59,27 +63,59 @@ class CudnnConvTransposeOpBase : public ConvTransposeUnpoolBase<CUDAContext> {
     CUDNN_ENFORCE(cudnnDestroyFilterDescriptor(filter_desc_));
     if (InputSize() == 3) {
       CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(bias_desc_));
+      CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(top_desc_for_bias_));
     }
     CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(top_desc_));
     CUDNN_ENFORCE(cudnnDestroyConvolutionDescriptor(conv_desc_));
   }
 
  protected:
-  vector<int64_t> cudnn_input_dims_;
-  vector<int64_t> cudnn_filter_dims_;
+  void SetTensor4DDescriptorWithGroup(
+      const cudnnDataType_t data_type,
+      const int N,
+      const int C,
+      const int H,
+      const int W,
+      cudnnTensorDescriptor_t* desc) const {
+#if CUDNN_VERSION_MIN(7, 0, 0)
+    const int CC = C;
+#else
+    const int CC = C / group_;
+#endif
+    switch (order_) {
+      case StorageOrder::NCHW: {
+        CUDNN_ENFORCE(cudnnSetTensor4dDescriptorEx(
+            *desc, data_type, N, CC, H, W, C * H * W, H * W, W, 1));
+        break;
+      }
+      case StorageOrder::NHWC: {
+        CUDNN_ENFORCE(cudnnSetTensor4dDescriptorEx(
+            *desc, data_type, N, CC, H, W, H * W * C, 1, W * C, C));
+        break;
+      }
+      default: {
+        LOG(FATAL) << "Unknown storage order: " << order_;
+      }
+    }
+  }
+
+  std::vector<std::int64_t> cudnn_input_dims_;
+  std::vector<std::int64_t> cudnn_filter_dims_;
 
   CuDNNWrapper cudnn_wrapper_;
   cudnnTensorDescriptor_t bottom_desc_;
   cudnnFilterDescriptor_t filter_desc_;
   cudnnTensorDescriptor_t bias_desc_;
   cudnnTensorDescriptor_t top_desc_;
+  cudnnTensorDescriptor_t top_desc_for_bias_;
   cudnnConvolutionDescriptor_t conv_desc_;
+
   const size_t cudnn_ws_nbytes_limit_;
   size_t cudnn_ws_nbytes_;
   bool exhaustive_search_;
   bool deterministic_;
   size_t cudnn_state_;
-  vector<int> force_algo_; // stored as FWD, dFILTER, dDATA
+  std::vector<int> force_algo_; // stored as FWD, dFILTER, dDATA
   bool enable_tensor_core_;
 };
 
@@ -141,10 +177,10 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
   int C = 0;
   switch (order_) {
     case StorageOrder::NHWC:
-      C = filter.dim32(3);
+      C = filter.dim32(3) * group_;
       break;
     case StorageOrder::NCHW:
-      C = filter.dim32(1);
+      C = filter.dim32(1) * group_;
       break;
     default:
       LOG(FATAL) << "Unknown storage order: " << order_;
@@ -162,9 +198,8 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
       H_out = Y->dim32(1);
       W_out = Y->dim32(2);
       CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
-      CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
       CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_w());
-      CAFFE_ENFORCE_EQ(filter.dim32(3), C);
+      CAFFE_ENFORCE_EQ(filter.dim32(3), C / group_);
       break;
     case StorageOrder::NCHW:
       N = X.dim32(0);
@@ -173,13 +208,14 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
       W = X.dim32(3);
       H_out = Y->dim32(2);
       W_out = Y->dim32(3);
-      CAFFE_ENFORCE_EQ(filter.dim32(1), C);
+      CAFFE_ENFORCE_EQ(filter.dim32(1), C / group_);
       CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_h());
       CAFFE_ENFORCE_EQ(filter.dim32(3), kernel_w());
       break;
     default:
       LOG(FATAL) << "Unknown storage order: " << order_;
   }
+  CAFFE_ENFORCE_EQ(M % group_, 0);
 
   if (InputSize() == 3) {
     auto& bias = Input(BIAS);
@@ -188,30 +224,29 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
   }
 
   // Set up the cudnn algorithms & workspace if necessary
-  bool input_changed = (X.sizes() != cudnn_input_dims_);
-  bool filter_changed = (filter.sizes() != cudnn_filter_dims_);
+  const bool input_changed = (X.sizes() != cudnn_input_dims_);
+  const bool filter_changed = (filter.sizes() != cudnn_filter_dims_);
 
   if (input_changed || filter_changed) {
     VLOG(1) << "Changing the cudnn descriptor configurations.";
     if (input_changed) {
       cudnn_input_dims_ = X.sizes().vec();
-      CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
-          bottom_desc_,
-          GetCudnnTensorFormat(order_),
-          cudnnTypeWrapper<T>::type,
-          N,
-          M,
-          H,
-          W));
+      SetTensor4DDescriptorWithGroup(
+          cudnnTypeWrapper<T>::type, N, M, H, W, &bottom_desc_);
     }
     if (filter_changed) {
       cudnn_filter_dims_ = filter.sizes().vec();
+#if CUDNN_VERSION_MIN(7, 0, 0)
+      const int MM = M;
+#else
+      const int MM = M / group_;
+#endif
       CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
           filter_desc_,
           cudnnTypeWrapper<T>::type,
           GetCudnnTensorFormat(order_),
-          M,
-          C,
+          MM,
+          C / group_,
           kernel_h(),
           kernel_w()));
       if (InputSize() == 3) {
@@ -226,14 +261,19 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
       }
     }
     // Set the output
-    CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
-        top_desc_,
-        GetCudnnTensorFormat(order_),
-        cudnnTypeWrapper<T>::type,
-        N,
-        C,
-        H_out,
-        W_out));
+    SetTensor4DDescriptorWithGroup(
+        cudnnTypeWrapper<T>::type, N, C, H_out, W_out, &top_desc_);
+    if (InputSize() == 3) {
+      CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
+          top_desc_for_bias_,
+          GetCudnnTensorFormat(order_),
+          cudnnTypeWrapper<T>::type,
+          N,
+          C,
+          H_out,
+          W_out));
+    }
+
     // Set the convolution descriptor
     CAFFE_ENFORCE_EQ(
         pad_t(),
@@ -246,7 +286,7 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
         "The current padding scheme leads to unequal padding on the left "
         "and right, which is not supported by cudnn.");
     // Set the convolution descriptor
-#if CUDNN_VERSION_MIN(6,0,0)
+#if CUDNN_VERSION_MIN(6, 0, 0)
     CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
         conv_desc_,
         pad_t(),
@@ -268,6 +308,7 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
         1,
         CUDNN_CROSS_CORRELATION));
 #endif
+
 #if CUDNN_VERSION_MIN(7, 0, 0)
     // enable TensorCore math if desired
     enable_tensor_core_ &= TensorCoreAvailable();
@@ -275,7 +316,10 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
       CUDNN_ENFORCE(
           cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH));
     }
+    // set cuDNN groups if appropriate
+    CUDNN_ENFORCE(cudnnSetConvolutionGroupCount(conv_desc_, group_));
 #endif
+
     if (force_algo_[ALGO_DGRAD] >= 0) {
       bwd_data_algo_ = (cudnnConvolutionBwdDataAlgo_t)force_algo_[ALGO_DGRAD];
     } else if (deterministic_) {
@@ -331,24 +375,56 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
     VLOG(1) << "CuDNN workspace size: " << bwd_data_ws_size;
   }
 
+  const T* X_data = X.template data<T>();
+  const T* filter_data = filter.template data<T>();
+  T* Y_data = Y->template mutable_data<T>();
+
   // Now, actually run the computation.
   // Filter
+#if CUDNN_VERSION_MIN(7, 0, 0)
   cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
     CUDNN_ENFORCE(cudnnConvolutionBackwardData(
         state->cudnn_handle(),
         cudnnTypeWrapper<T>::kOne(),
         filter_desc_,
-        filter.template data<T>(),
+        filter_data,
         bottom_desc_,
-        X.template data<T>(),
+        X_data,
         conv_desc_,
         bwd_data_algo_,
         state->workspace().get(cudnn_ws_nbytes_),
         cudnn_ws_nbytes_,
         cudnnTypeWrapper<T>::kZero(),
         top_desc_,
-        Y->template mutable_data<T>()));
+        Y_data));
   });
+#else
+  const int X_HxW = H * W;
+  const int Y_HxW = H_out * W_out;
+  const int group_offset_X =
+      order_ == StorageOrder::NCHW ? M / group_ * X_HxW : M / group_;
+  const int group_offset_Y =
+      order_ == StorageOrder::NCHW ? C / group_ * Y_HxW : C / group_;
+  const int group_offset_filter = filter.numel() / group_;
+  for (int i = 0; i < group_; ++i) {
+    cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
+      CUDNN_ENFORCE(
+          cudnnConvolutionBackwardData(state->cudnn_handle(),
+                                       cudnnTypeWrapper<T>::kOne(),
+                                       filter_desc_,
+                                       filter_data + i * group_offset_filter,
+                                       bottom_desc_,
+                                       X_data + i * group_offset_X,
+                                       conv_desc_,
+                                       bwd_data_algo_,
+                                       state->workspace().get(cudnn_ws_nbytes_),
+                                       cudnn_ws_nbytes_,
+                                       cudnnTypeWrapper<T>::kZero(),
+                                       top_desc_,
+                                       Y_data + i * group_offset_Y));
+    });
+  }
+#endif
   // Bias
   if (InputSize() == 3) {
     CUDNN_ENFORCE(cudnnAddTensor(
@@ -357,7 +433,7 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
         bias_desc_,
         Input(BIAS).template data<T>(),
         cudnnTypeWrapper<T>::kOne(),
-        top_desc_,
+        top_desc_for_bias_,
         Y->template mutable_data<T>()));
   }
   // Done.
@@ -368,19 +444,19 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
 // consolidating them.
 template <typename T>
 bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
-  auto& X = Input(INPUT);
-  auto& filter = Input(FILTER);
-  auto& dY = Input(OUTPUT_GRAD);
+  const auto& X = Input(INPUT);
+  const auto& filter = Input(FILTER);
+  const auto& dY = Input(OUTPUT_GRAD);
 
   CAFFE_ENFORCE_EQ(X.dim(), 4);
   CAFFE_ENFORCE_EQ(filter.dim(), 4);
   int C = 0;
   switch (order_) {
     case StorageOrder::NHWC:
-      C = filter.dim32(3);
+      C = filter.dim32(3) * group_;
       break;
     case StorageOrder::NCHW:
-      C = filter.dim32(1);
+      C = filter.dim32(1) * group_;
       break;
     default:
       LOG(FATAL) << "Unknown storage order: " << order_;
@@ -398,7 +474,7 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
       CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
       CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
       CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_w());
-      CAFFE_ENFORCE_EQ(filter.dim32(3), C);
+      CAFFE_ENFORCE_EQ(filter.dim32(3), C / group_);
       break;
     case StorageOrder::NCHW:
       N = X.dim32(0);
@@ -407,41 +483,42 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
       W = X.dim32(3);
       H_out = dY.dim32(2);
       W_out = dY.dim32(3);
-      CAFFE_ENFORCE_EQ(filter.dim32(1), C);
+      CAFFE_ENFORCE_EQ(filter.dim32(1), C / group_);
       CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_h());
       CAFFE_ENFORCE_EQ(filter.dim32(3), kernel_w());
       break;
     default:
       LOG(FATAL) << "Unknown storage order: " << order_;
   }
+  CAFFE_ENFORCE_EQ(M % group_, 0);
+
   // Since we only handle LegacyPadding::NOTSET, we don't need to
   // compute padding.
   auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype<T>());
 
   // Set up the cudnn algorithms & workspace if necessary
-  bool input_changed = (X.sizes() != cudnn_input_dims_);
-  bool filter_changed = (filter.sizes() != cudnn_filter_dims_);
+  const bool input_changed = (X.sizes() != cudnn_input_dims_);
+  const bool filter_changed = (filter.sizes() != cudnn_filter_dims_);
   if (input_changed || filter_changed) {
     VLOG(1) << "Changing the cudnn descriptor configurations.";
     if (input_changed) {
       cudnn_input_dims_ = X.sizes().vec();
-      CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
-          bottom_desc_,
-          GetCudnnTensorFormat(order_),
-          cudnnTypeWrapper<T>::type,
-          N,
-          M,
-          H,
-          W));
+      SetTensor4DDescriptorWithGroup(
+          cudnnTypeWrapper<T>::type, N, M, H, W, &bottom_desc_);
     }
     if (filter_changed) {
       cudnn_filter_dims_ = filter.sizes().vec();
+#if CUDNN_VERSION_MIN(7, 0, 0)
+      const int MM = M;
+#else
+      const int MM = M / group_;
+#endif
       CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
           filter_desc_,
           cudnnTypeWrapper<T>::type,
           GetCudnnTensorFormat(order_),
-          M,
-          C,
+          MM,
+          C / group_,
           kernel_h(),
           kernel_w()));
       if (!no_bias_) {
@@ -456,14 +533,19 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
       }
     }
     // Set the output
-    CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
-        top_desc_,
-        GetCudnnTensorFormat(order_),
-        cudnnTypeWrapper<T>::type,
-        N,
-        C,
-        H_out,
-        W_out));
+    SetTensor4DDescriptorWithGroup(
+        cudnnTypeWrapper<T>::type, N, C, H_out, W_out, &top_desc_);
+    if (!no_bias_) {
+      CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
+          top_desc_for_bias_,
+          GetCudnnTensorFormat(order_),
+          cudnnTypeWrapper<T>::type,
+          N,
+          C,
+          H_out,
+          W_out));
+    }
+
     // Set the convolution descriptor
     CAFFE_ENFORCE_EQ(
         pad_t(),
@@ -475,7 +557,7 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
         pad_r(),
         "The current padding scheme leads to unequal padding on the left "
         "and right, which is not supported by cudnn.");
-#if CUDNN_VERSION_MIN(6,0,0)
+#if CUDNN_VERSION_MIN(6, 0, 0)
     CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
         conv_desc_,
         pad_t(),
@@ -504,6 +586,8 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
       CUDNN_ENFORCE(
           cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH));
     }
+    // set cuDNN groups if appropriate
+    CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc_, group_));
 #endif
     if (force_algo_[ALGO_WGRAD] >= 0) {
       bwd_filter_algo_ =
@@ -622,13 +706,14 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
     CUDNN_ENFORCE(cudnnConvolutionBackwardBias(
         cudnn_wrapper_.inline_cudnn_handle(),
         cudnnTypeWrapper<T>::kOne(),
-        top_desc_,
+        top_desc_for_bias_,
         dY.template data<T>(),
         cudnnTypeWrapper<T>::kZero(),
         bias_desc_,
         dbias->template mutable_data<T>()));
   }
 
+#if CUDNN_VERSION_MIN(7, 0, 0)
   cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
     CUDNN_ENFORCE(cudnnConvolutionBackwardFilter(
         state->cudnn_handle(),
@@ -647,7 +732,6 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
 
     if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
       // Compute the gradient w.r.t. the input.
-
       auto* dX = Output(
           no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD,
           X.sizes(),
@@ -668,6 +752,55 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
           dX->template mutable_data<T>()));
     }
   });
+#else
+  const int X_HxW = H * W;
+  const int Y_HxW = H_out * W_out;
+  const int group_offset_X =
+      order_ == StorageOrder::NCHW ? M / group_ * X_HxW : M / group_;
+  const int group_offset_Y =
+      order_ == StorageOrder::NCHW ? C / group_ * Y_HxW : C / group_;
+  const int group_offset_filter = filter.numel() / group_;
+  for (int i = 0; i < group_; ++i) {
+    cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
+      CUDNN_ENFORCE(cudnnConvolutionBackwardFilter(
+          state->cudnn_handle(),
+          cudnnTypeWrapper<T>::kOne(),
+          top_desc_,
+          dY.template data<T>() + i * group_offset_Y,
+          bottom_desc_,
+          X.template data<T>() + i * group_offset_X,
+          conv_desc_,
+          bwd_filter_algo_,
+          state->workspace().get(cudnn_ws_nbytes_),
+          cudnn_ws_nbytes_,
+          cudnnTypeWrapper<T>::kZero(),
+          filter_desc_,
+          dfilter->template mutable_data<T>() + i * group_offset_filter));
+      if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
+        // Compute the gradient w.r.t. the input.
+        auto* dX = Output(
+            no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD,
+            X.sizes(),
+            at::dtype<T>());
+        cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
+          CUDNN_ENFORCE(cudnnConvolutionForward(
+              state->cudnn_handle(),
+              cudnnTypeWrapper<T>::kOne(),
+              top_desc_,
+              dY.template data<T>() + i * group_offset_Y,
+              filter_desc_,
+              filter.template data<T>() + i * group_offset_filter,
+              conv_desc_,
+              algo_,
+              state->workspace().get(cudnn_ws_nbytes_),
+              cudnn_ws_nbytes_,
+              cudnnTypeWrapper<T>::kZero(),
+              bottom_desc_,
+              dX->template mutable_data<T>() + i * group_offset_X));
+        });
+      }
+    });
+  }
+#endif
   return true;
 }
 
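Not part of the patch, but worth noting how the two cuDNN paths above differ: with cuDNN 7 and later the code simply calls cudnnSetConvolutionGroupCount on the shared convolution descriptor, while on older cuDNN it describes C / group channels per tensor descriptor (keeping the full-tensor strides) and loops over the groups, advancing the data pointers by fixed offsets. A small Python restatement of that offset arithmetic, with all names taken from the fallback code above:

def group_offsets(order, M, C, H, W, H_out, W_out, group, filter_numel):
    # Per-group pointer offsets used by the pre-cuDNN-7 fallback (sketch).
    X_HxW = H * W
    Y_HxW = H_out * W_out
    group_offset_X = (M // group) * X_HxW if order == "NCHW" else M // group
    group_offset_Y = (C // group) * Y_HxW if order == "NCHW" else C // group
    group_offset_filter = filter_numel // group
    return group_offset_X, group_offset_Y, group_offset_filter

Group i of the input, output, and filter then starts i * group_offset_* elements past the corresponding base pointer.
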
caffe2/operators/conv_transpose_op_impl.h
index 41af81c..333f782 100644
@@ -3,11 +3,15 @@
 #ifndef CAFFE2_OPERATORS_CONV_TRANSPOSE_OP_IMPL_H_
 #define CAFFE2_OPERATORS_CONV_TRANSPOSE_OP_IMPL_H_
 
+#include "caffe2/operators/conv_transpose_op.h"
+
+#include <array>
+#include <vector>
+
 #include "caffe2/core/context.h"
 #include "caffe2/core/logging.h"
 #include "caffe2/core/operator.h"
 #include "caffe2/operators/conv_op_shared.h"
-#include "caffe2/operators/conv_transpose_op.h"
 #include "caffe2/operators/conv_transpose_unpool_op_base.h"
 #include "caffe2/utils/math.h"
 
@@ -17,551 +21,618 @@ namespace caffe2 {
 
 template <typename T, class Context>
 bool ConvTransposeOp<T, Context>::RunOnDeviceWithOrderNCHW() {
-  const Tensor& X = Input(INPUT);
-  auto& filter = Input(FILTER);
-  const int N = X.dim32(0), M = X.dim32(1), H = X.dim32(2), W = X.dim32(3);
-  CAFFE_ENFORCE(filter.dim() == 4, "filter must be 4D tensor");
-  CAFFE_ENFORCE(
-      filter.dim32(0) == M,
-      "filter number must be equal to input channel number");
-  const int C = filter.dim32(1);
-  CAFFE_ENFORCE(
-      filter.dim32(2) == this->kernel_h(),
+  const auto& X = Input(INPUT);
+  const auto& filter = Input(FILTER);
+  CAFFE_ENFORCE_EQ(X.dim(), 4, "Input must be 4D tensor");
+  CAFFE_ENFORCE_EQ(filter.dim(), 4, "filter must be 4D tensor");
+  const int N = X.dim32(0);
+  const int M = X.dim32(1);
+  const int H = X.dim32(2);
+  const int W = X.dim32(3);
+  const int G = group_;
+  CAFFE_ENFORCE_EQ(M, filter.dim32(0));
+  CAFFE_ENFORCE_EQ(
+      M % G, 0, "The number of input channels is not divisible by group.");
+  const int C = filter.dim32(1) * G;
+  CAFFE_ENFORCE_EQ(
+      filter.dim32(2),
+      kernel_h(),
       "filter height must be equal to kernel height");
-  CAFFE_ENFORCE(
-      filter.dim32(3) == this->kernel_w(),
+  CAFFE_ENFORCE_EQ(
+      filter.dim32(3),
+      this->kernel_w(),
       "filter width must be equal to kernel width");
-  auto sizes = ConvTransposeUnpoolBase<Context>::GetOutputSize(X, C);
-  Tensor* Y = Output(0, sizes, at::dtype<T>());
+  const std::vector<std::int64_t> Y_dims =
+      ConvTransposeUnpoolBase<Context>::GetOutputSize(X, C);
+  auto* Y = Output(0, Y_dims, at::dtype<T>());
 
-  const int kernel_dim = C * this->kernel_h() * this->kernel_w();
-  const int input_image_size = H * W;
-  const int output_image_size = Y->dim32(2) * Y->dim32(3);
+  if (N == 0) {
+    return true;
+  }
 
+  const int K_HxW = kernel_h() * kernel_w();
+  const int kernel_dim = C / G * K_HxW;
+  const int X_HxW = H * W;
+  const int Y_HxW = Y->dim32(2) * Y->dim32(3);
+
+  const T* X_data = X.template data<T>();
+  const T* filter_data = filter.template data<T>();
+  const T* bias_data = nullptr;
   if (InputSize() == 3) {
     auto& bias = Input(BIAS);
-    CAFFE_ENFORCE(bias.dim() == 1, "bias must be 1D tensor");
-    CAFFE_ENFORCE(
-        bias.dim32(0) == C,
+    CAFFE_ENFORCE_EQ(bias.dim(), 1, "bias must be 1D tensor");
+    CAFFE_ENFORCE_EQ(
+        bias.dim32(0),
+        C,
         "bias dimension must be equal to output channel number");
-    ReinitializeTensor(
-        &bias_multiplier_,
-        {1, output_image_size},
-        at::dtype<T>().device(Context::GetDeviceType()));
-      T* bm_data = bias_multiplier_.template mutable_data<T>();
-      math::Set<T, Context>(
-          output_image_size,
-          static_cast<T>(1),
-          bm_data,
-          &context_);
+    bias_data = bias.template data<T>();
   }
+  T* Y_data = Y->template mutable_data<T>();
 
-  const T* Xdata = X.template data<T>();
-  const T* filter_data = filter.template data<T>();
-  T* Ydata = Y->template mutable_data<T>();
+  const std::vector<std::int64_t> buffer_shape = {
+      C, kernel_h(), kernel_w(), H, W};
 
-  auto f = [&](Tensor* col_buffer) {
-    ReinitializeTensor(col_buffer, vector<int64_t>{C, this->kernel_h(), this->kernel_w(), H, W}, at::dtype<T>().device(Context::GetDeviceType()));
+  const auto func = [&](Tensor* col_buffer) {
+    ReinitializeTensor(
+        col_buffer,
+        buffer_shape,
+        at::dtype<T>().device(Context::GetDeviceType()));
     T* col_buffer_data = col_buffer->template mutable_data<T>();
-    for (auto image_id = 0; image_id < N; ++image_id) {
+    for (int image_id = 0; image_id < N; ++image_id) {
       // Weight term
-      math::Gemm<T, Context>(
-          CblasTrans,
-          CblasNoTrans,
-          kernel_dim,
-          input_image_size,
-          M,
-          1,
-          filter_data,
-          Xdata,
-          0,
-          col_buffer_data,
-          &context_);
+      if (G == 1) {
+        math::Gemm<T, Context>(
+            CblasTrans,
+            CblasNoTrans,
+            kernel_dim,
+            X_HxW,
+            M,
+            1.0f,
+            filter_data,
+            X_data + image_id * M * X_HxW,
+            0.0f,
+            col_buffer_data,
+            &context_);
+      } else {
+        math::GemmStridedBatched<T, Context>(
+            CblasTrans,
+            CblasNoTrans,
+            G,
+            kernel_dim,
+            X_HxW,
+            M / G,
+            1.0f,
+            filter_data,
+            M / G * kernel_dim,
+            X_data + image_id * M * X_HxW,
+            M / G * X_HxW,
+            0.0f,
+            col_buffer_data,
+            col_buffer->numel() / G,
+            &context_);
+      }
 
       // Col2Im
       math::Col2Im<T, Context, StorageOrder::NCHW>(
           C,
           Y->dim32(2),
           Y->dim32(3),
-          this->kernel_h(),
-          this->kernel_w(),
+          kernel_h(),
+          kernel_w(),
           1,
           1,
-          this->pad_t(),
-          this->pad_l(),
-          this->pad_b(),
-          this->pad_r(),
-          this->stride_h(),
-          this->stride_w(),
+          pad_t(),
+          pad_l(),
+          pad_b(),
+          pad_r(),
+          stride_h(),
+          stride_w(),
           col_buffer_data,
-          Ydata,
+          Y_data + image_id * C * Y_HxW,
           &context_);
 
-      // Bias term
-      if (InputSize() == 3) {
-        const T* bias_data = Input(BIAS).template data<T>();
-        const T* bm_data = bias_multiplier_.template data<T>();
-#if !defined(__ARM_NEON__) && !defined(__ARM_NEON)
-        math::Gemm<T, Context>(
-            CblasNoTrans,
-            CblasNoTrans,
-            C,
-            output_image_size,
-            1,
-            1,
-            bias_data,
-            bm_data,
-            1,
-            Ydata,
-            &context_);
-#else
+      if (bias_data != nullptr) {
+        // Bias term
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
         math::BiasCHW<T, Context>(
             bias_data,
-            bm_data,
+            nullptr,
             C,
-            output_image_size,
-            Ydata,
+            Y_HxW,
+            Y_data + image_id * C * Y_HxW,
             &context_);
 #endif // !defined(__ARM_NEON__) && !defined(__ARM_NEON)
       }
-
-      Xdata += M * H * W;
-      Ydata += Y->numel() / Y->dim32(0);
+    }
+    if (bias_data != nullptr) {
+#if !defined(__ARM_NEON__) && !defined(__ARM_NEON)
+      // Bias term
+      const std::array<int, 3> Y_dims = {N, C, Y_HxW};
+      const std::array<int, 3> b_dims = {1, C, 1};
+      math::Add<T, Context>(
+          3,
+          Y_dims.data(),
+          3,
+          b_dims.data(),
+          Y_data,
+          bias_data,
+          Y_data,
+          &context_);
+#endif
     }
   };
+
   if (FLAGS_caffe2_force_shared_col_buffer || shared_buffer_) {
-    runWithSharedBuffer<Context>(ws_, f);
+    runWithSharedBuffer<Context>(ws_, func);
   } else {
-    f(&col_buffer_);
+    func(&col_buffer_);
   }
   return true;
 }
 
 template <typename T, class Context>
 bool ConvTransposeOp<T, Context>::RunOnDeviceWithOrderNHWC() {
-  const Tensor& X = Input(INPUT);
+  const auto& X = Input(INPUT);
   auto& filter = Input(FILTER);
-  const auto N = X.dim32(0), H = X.dim32(1), W = X.dim32(2), M = X.dim32(3);
-  CAFFE_ENFORCE(filter.dim() == 4, "filter must be 4D tensor");
-  CAFFE_ENFORCE(
-      filter.dim32(0) == M,
+  CAFFE_ENFORCE_EQ(filter.dim(), 4, "filter must be 4D tensor");
+  const int N = X.dim32(0);
+  const int H = X.dim32(1);
+  const int W = X.dim32(2);
+  const int M = X.dim32(3);
+  const int G = group_;
+  CAFFE_ENFORCE_EQ(
+      filter.dim32(0),
+      M,
       "filter number must be equal to input channel number");
-  CAFFE_ENFORCE(
-      filter.dim32(1) == this->kernel_h(),
+  CAFFE_ENFORCE_EQ(
+      M % G, 0, "The number of input channels is not divisible by group.");
+  const int C = filter.dim32(3) * G;
+  CAFFE_ENFORCE_EQ(
+      filter.dim32(1),
+      kernel_h(),
       "filter height must be equal to kernel height");
-  CAFFE_ENFORCE(
-      filter.dim32(2) == this->kernel_w(),
+  CAFFE_ENFORCE_EQ(
+      filter.dim32(2),
+      kernel_w(),
       "filter width must be equal to kernel width");
-  const int C = filter.dim32(3);
-  auto sizes = ConvTransposeUnpoolBase<Context>::GetOutputSize(X, C);
-  Tensor* Y = Output(0, sizes, at::dtype<T>());
 
-  const auto kernel_dim = C * this->kernel_h() * this->kernel_w();
-  const auto input_image_size = H * W;
-  const auto output_image_size = Y->dim32(1) * Y->dim32(2);
+  const std::vector<std::int64_t> Y_dims =
+      ConvTransposeUnpoolBase<Context>::GetOutputSize(X, C);
+  auto* Y = Output(0, Y_dims, at::dtype<T>());
+  if (N == 0) {
+    return true;
+  }
+
+  const int K_HxW = kernel_h() * kernel_w();
+  const int kernel_dim = C / G * K_HxW;
+  const int X_HxW = H * W;
+  const int Y_HxW = Y->dim32(1) * Y->dim32(2);
 
+  const T* X_data = X.template data<T>();
+  const T* filter_data = filter.template data<T>();
+  const T* bias_data = nullptr;
   if (InputSize() == 3) {
     auto& bias = Input(BIAS);
-    CAFFE_ENFORCE(bias.dim() == 1, "bias must be 1D tensor");
-    CAFFE_ENFORCE(
-        bias.dim32(0) == C,
+    CAFFE_ENFORCE_EQ(bias.dim(), 1, "bias must be 1D tensor");
+    CAFFE_ENFORCE_EQ(
+        bias.dim32(0),
+        C,
         "bias dimension must be equal to output channel number");
-    // TODO(jerryzh): is it OK to remove the check of whether numel is output_image_size
-    ReinitializeTensor(
-        &bias_multiplier_,
-        {1, output_image_size},
-        at::dtype<T>().device(Context::GetDeviceType()));
-      T* bm_data = bias_multiplier_.template mutable_data<T>();
-      math::Set<T, Context>(
-          output_image_size,
-          static_cast<T>(1),
-          bm_data,
-          &context_);
+    bias_data = bias.template data<T>();
   }
-  const T* Xdata = X.template data<T>();
-  const T* filter_data = filter.template data<T>();
-  T* Ydata = Y->template mutable_data<T>();
+  T* Y_data = Y->template mutable_data<T>();
 
-  auto f = [&](Tensor* /*col_buffer*/) {
+  const std::vector<std::int64_t> buffer_shape = {
+      G, H, W, kernel_h(), kernel_w(), C / G};
+  const auto func = [&](Tensor* /*col_buffer*/) {
     ReinitializeTensor(
         &col_buffer_,
-        vector<int64_t>{H, W, this->kernel_h(), this->kernel_w(), C},
+        buffer_shape,
         at::dtype<T>().device(Context::GetDeviceType()));
     T* col_buffer_data = col_buffer_.template mutable_data<T>();
-    for (auto image_id = 0; image_id < N; ++image_id) {
+    for (int image_id = 0; image_id < N; ++image_id) {
       // Weight term
-      math::Gemm<T, Context>(
-          CblasNoTrans,
-          CblasNoTrans,
-          input_image_size,
-          kernel_dim,
-          M,
-          1,
-          Xdata,
-          filter_data,
-          0,
-          col_buffer_data,
-          &context_);
+      if (G == 1) {
+        math::Gemm<T, Context>(
+            CblasNoTrans,
+            CblasNoTrans,
+            X_HxW,
+            kernel_dim,
+            M,
+            1.0f,
+            X_data + image_id * M * X_HxW,
+            filter_data,
+            0.0f,
+            col_buffer_data,
+            &context_);
+      } else {
+        for (int group_id = 0; group_id < G; ++group_id) {
+          math::GemmEx<T, Context>(
+              CblasNoTrans,
+              CblasNoTrans,
+              X_HxW,
+              kernel_dim,
+              M / G,
+              1.0f,
+              X_data + image_id * M * X_HxW + group_id * M / G,
+              M,
+              filter_data + group_id * M / G * kernel_dim,
+              kernel_dim,
+              0.0f,
+              col_buffer_data + group_id * kernel_dim,
+              G * kernel_dim,
+              &context_);
+        }
+      }
       // Col2Im
       math::Col2Im<T, Context, StorageOrder::NHWC>(
           C,
           Y->dim32(1),
           Y->dim32(2),
-          this->kernel_h(),
-          this->kernel_w(),
+          kernel_h(),
+          kernel_w(),
           1,
           1,
-          this->pad_t(),
-          this->pad_l(),
-          this->pad_b(),
-          this->pad_r(),
-          this->stride_h(),
-          this->stride_w(),
+          pad_t(),
+          pad_l(),
+          pad_b(),
+          pad_r(),
+          stride_h(),
+          stride_w(),
           col_buffer_data,
-          Ydata,
-          &context_);
+          Y_data + image_id * C * Y_HxW,
+          &context_,
+          G);
+    }
+    if (bias_data != nullptr) {
       // Bias term
-      if (InputSize() == 3) {
-        const T* bm_data = bias_multiplier_.template data<T>();
-        const T* bias_data = Input(BIAS).template data<T>();
-        math::Gemm<T, Context>(
-            CblasNoTrans,
-            CblasNoTrans,
-            output_image_size,
-            C,
-            1,
-            1,
-            bm_data,
-            bias_data,
-            1,
-            Ydata,
-            &context_);
-      }
-      Xdata += M * H * W;
-      Ydata += Y->numel() / Y->dim32(0);
+      const std::array<int, 2> Y_dims = {N * Y_HxW, C};
+      const std::array<int, 2> b_dims = {1, C};
+      math::Add<T, Context>(
+          2,
+          Y_dims.data(),
+          2,
+          b_dims.data(),
+          Y_data,
+          bias_data,
+          Y_data,
+          &context_);
     }
   };
+
   if (FLAGS_caffe2_force_shared_col_buffer || shared_buffer_) {
-    runWithSharedBuffer<Context>(ws_, f);
+    runWithSharedBuffer<Context>(ws_, func);
   } else {
-    f(&col_buffer_);
+    func(&col_buffer_);
   }
   return true;
 }
 
 template <typename T, class Context>
 bool ConvTransposeGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
-  auto& X = Input(INPUT);
-  auto& filter = Input(FILTER);
-  auto& dY = Input(OUTPUT_GRAD);
-
-  const int N = X.dim32(0), M = X.dim32(1), H = X.dim32(2), W = X.dim32(3);
-  // We only handle LegacyPadding::NOTSET case and ignore cases of
-  // LegacyPadding::VALID and LegacyPadding::SAME
-  // Thus, we don't need to manually compute padding values
-  // We simply use the values from the user
-  CAFFE_ENFORCE(filter.dim() == 4);
-  const int C = filter.dim32(1);
-  CAFFE_ENFORCE(
-      filter.dim32(2) == this->kernel_h(),
+  const auto& X = Input(INPUT);
+  const auto& filter = Input(FILTER);
+  const auto& dY = Input(OUTPUT_GRAD);
+  CAFFE_ENFORCE_EQ(filter.dim(), 4);
+  const int N = X.dim32(0);
+  const int M = X.dim32(1);
+  const int H = X.dim32(2);
+  const int W = X.dim32(3);
+  const int G = group_;
+  CAFFE_ENFORCE_EQ(M, filter.dim32(0));
+  CAFFE_ENFORCE_EQ(
+      M % G, 0, "The number of input channels is not divisible by group.");
+  const int C = filter.dim32(1) * G;
+  CAFFE_ENFORCE_EQ(C, dY.dim32(1));
+  CAFFE_ENFORCE_EQ(
+      filter.dim32(2),
+      kernel_h(),
       "filter height must be equal to kernel height");
-  CAFFE_ENFORCE(
-      filter.dim32(3) == this->kernel_w(),
+  CAFFE_ENFORCE_EQ(
+      filter.dim32(3),
+      this->kernel_w(),
       "filter width must be equal to kernel width");
+
+  const int K_HxW = kernel_h() * kernel_w();
+  const int kernel_dim = C / G * K_HxW;
+  const int X_HxW = H * W;
+  const int Y_HxW = dY.dim32(2) * dY.dim32(3);
   auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype<T>());
 
-  const int kernel_dim = C * this->kernel_h() * this->kernel_w();
-  const int output_image_size = dY.dim32(2) * dY.dim32(3);
-  // The col buffer is stored in CHW order as well
-  ReinitializeTensor(
-      &col_buffer_,
-      vector<int64_t>{C, this->kernel_h(), this->kernel_w(), H, W},
-      at::dtype<T>().device(Context::GetDeviceType()));
-  if (!no_bias_) {
-    auto* dbias = Output(BIAS_OR_INPUT_GRAD);
-    dbias->Resize(C);
-    // TODO(jerryzh): is it OK to remove the check of whether numel is output_image_size
-    ReinitializeTensor(
-        &bias_multiplier_,
-        {1, output_image_size},
-        at::dtype<T>().device(Context::GetDeviceType()));
-    T* bm_data = bias_multiplier_.template mutable_data<T>();
-    math::Set<T, Context>(
-        output_image_size,
-        static_cast<T>(1),
-        bm_data,
-        &context_);
-  }
-  T* col_buffer_data = col_buffer_.template mutable_data<T>();
-  const T* Xdata = X.template data<T>();
+  const T* X_data = X.template data<T>();
   const T* filter_data = filter.template data<T>();
-  const T* dYdata = dY.template data<T>();
+  const T* dY_data = dY.template data<T>();
   T* dfilter_data = dfilter->template mutable_data<T>();
-  // Pre-setting the gradients to zero
-  math::Set<T, Context>(dfilter->numel(), 0, dfilter_data, &context_);
+  T* dbias_data = nullptr;
+  T* dX_data = nullptr;
   if (!no_bias_) {
-    auto* dbias = Output(BIAS_OR_INPUT_GRAD);
-    T* dbias_data = dbias->template mutable_data<T>();
-    math::Set<T, Context>(dbias->numel(), 0, dbias_data, &context_);
+    auto* dbias = Output(BIAS_OR_INPUT_GRAD, {C}, at::dtype<T>());
+    dbias_data = dbias->template mutable_data<T>();
   }
-  for (auto image_id = 0; image_id < N; ++image_id) {
+  const bool compute_dX =
+      (OutputSize() == 3) || (no_bias_ && (OutputSize() == 2));
+  if (compute_dX) {
+    auto* dX = Output(
+        no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD, X.sizes(), at::dtype<T>());
+    dX_data = dX->template mutable_data<T>();
+  }
+  math::Set<T, Context>(filter.numel(), T(0), dfilter_data, &context_);
+
+  if (N == 0) {
+    math::Set<T, Context>(C, T(0), dbias_data, &context_);
+    return true;
+  }
+
+  ReinitializeTensor(
+      &col_buffer_,
+      std::vector<std::int64_t>{C, kernel_h(), kernel_w(), H, W},
+      at::dtype<T>().device(Context::GetDeviceType()));
+  T* col_buffer_data = col_buffer_.template mutable_data<T>();
+
+  for (int image_id = 0; image_id < N; ++image_id) {
     // gradient w.r.t. filters. Im2Col followed by Gemm
     // Im2Col.
     math::Im2Col<T, Context, StorageOrder::NCHW>(
         C,
         dY.dim32(2),
         dY.dim32(3),
-        this->kernel_h(),
-        this->kernel_w(),
+        kernel_h(),
+        kernel_w(),
         1,
         1,
-        this->pad_t(),
-        this->pad_l(),
-        this->pad_b(),
-        this->pad_r(),
-        this->stride_h(),
-        this->stride_w(),
-        dYdata,
+        pad_t(),
+        pad_l(),
+        pad_b(),
+        pad_r(),
+        stride_h(),
+        stride_w(),
+        dY_data + image_id * C * Y_HxW,
         col_buffer_data,
         &context_);
     // Gemm
-    math::Gemm<T, Context>(
-        CblasNoTrans,
-        CblasTrans,
-        M,
-        kernel_dim,
-        H * W,
-        1,
-        Xdata,
-        col_buffer_data,
-        1,
-        dfilter_data,
-        &context_);
-    // gradient w.r.t. bias
-    if (!no_bias_) {
-      const T* bm_data = bias_multiplier_.template data<T>();
-      T* input_grad_data = Output(BIAS_OR_INPUT_GRAD)->template mutable_data<T>();
+    if (G == 1) {
       math::Gemm<T, Context>(
           CblasNoTrans,
-          CblasNoTrans,
-          C,
-          1,
-          output_image_size,
-          1,
-          dYdata,
-          bm_data,
-          1,
-          input_grad_data,
-          &context_);
-    }
-    dYdata += dY.numel() / dY.dim32(0);
-    Xdata += X.numel() / X.dim32(0);
-  }
-  if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
-    // Compute gradients w.r.t. the input
-    // Since we have changed dYdata in the above loop, we will need to reset.
-    dYdata = dY.template data<T>();
-
-    auto* dX = Output(
-        no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD, X.sizes(), at::dtype<T>());
-    T* dXdata = dX->template mutable_data<T>();
-    for (auto image_id = 0; image_id < N; ++image_id) {
-      // Im2Col.
-      // TODO(zyan3): Probably duplicate work as in gradient computation
-      // w.r.t filters
-      math::Im2Col<T, Context, StorageOrder::NCHW>(
-          C,
-          dY.dim32(2),
-          dY.dim32(3),
-          this->kernel_h(),
-          this->kernel_w(),
-          1,
-          1,
-          this->pad_t(),
-          this->pad_l(),
-          this->pad_b(),
-          this->pad_r(),
-          this->stride_h(),
-          this->stride_w(),
-          dYdata,
+          CblasTrans,
+          M,
+          kernel_dim,
+          X_HxW,
+          1.0f,
+          X_data + image_id * M * X_HxW,
           col_buffer_data,
+          1.0f,
+          dfilter_data,
           &context_);
-      // Gemm
-      math::Gemm<T, Context>(
+    } else {
+      math::GemmStridedBatched<T, Context>(
           CblasNoTrans,
-          CblasNoTrans,
-          M,
-          H * W,
+          CblasTrans,
+          G,
+          M / G,
           kernel_dim,
-          1,
-          filter_data,
+          X_HxW,
+          1.0f,
+          X_data + image_id * M * X_HxW,
+          M / G * X_HxW,
           col_buffer_data,
-          0,
-          dXdata,
+          col_buffer_.numel() / G,
+          1.0f,
+          dfilter_data,
+          M / G * kernel_dim,
           &context_);
-      dYdata += dY.numel() / dY.dim32(0);
-      dXdata += X.numel() / X.dim32(0);
+    }
+
+    if (dX_data != nullptr) {
+      // Compute gradients w.r.t. the input
+      if (G == 1) {
+        math::Gemm<T, Context>(
+            CblasNoTrans,
+            CblasNoTrans,
+            M,
+            X_HxW,
+            kernel_dim,
+            1.0f,
+            filter_data,
+            col_buffer_data,
+            0.0f,
+            dX_data + image_id * M * X_HxW,
+            &context_);
+      } else {
+        math::GemmStridedBatched<T, Context>(
+            CblasNoTrans,
+            CblasNoTrans,
+            G,
+            M / G,
+            X_HxW,
+            kernel_dim,
+            1.0f,
+            filter_data,
+            M / G * kernel_dim,
+            col_buffer_data,
+            col_buffer_.numel() / G,
+            0.0f,
+            dX_data + image_id * M * X_HxW,
+            M / G * X_HxW,
+            &context_);
+      }
     }
   }
+
+  if (dbias_data != nullptr) {
+    // gradient w.r.t. bias
+    const std::array<int, 3> Y_dims = {N, C, Y_HxW};
+    const std::array<int, 3> b_dims = {1, C, 1};
+    math::ReduceSum<T, Context>(
+        3, Y_dims.data(), b_dims.data(), T(1), dY_data, dbias_data, &context_);
+  }
+
   return true;
 }
 
 template <typename T, class Context>
 bool ConvTransposeGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
-  auto& X = Input(INPUT);
-  auto& filter = Input(FILTER);
-  auto& dY = Input(OUTPUT_GRAD);
-
-  const int N = X.dim32(0), H = X.dim32(1), W = X.dim32(2), M = X.dim32(3);
-  // We only handle LegacyPadding::NOTSET case and ignore cases of
-  // LegacyPadding::VALID and LegacyPadding::SAME
-  // Thus, we don't need to manually compute padding values
-  // We simply use the values from the user
-  CAFFE_ENFORCE(filter.dim() == 4, "filter must be 4D tensor");
-  CAFFE_ENFORCE(
-      filter.dim32(1) == this->kernel_h(),
+  const auto& X = Input(INPUT);
+  const auto& filter = Input(FILTER);
+  const auto& dY = Input(OUTPUT_GRAD);
+  CAFFE_ENFORCE_EQ(filter.dim(), 4);
+  const int N = X.dim32(0);
+  const int H = X.dim32(1);
+  const int W = X.dim32(2);
+  const int M = X.dim32(3);
+  const int G = group_;
+  CAFFE_ENFORCE_EQ(M, filter.dim32(0));
+  CAFFE_ENFORCE_EQ(
+      M % G, 0, "The number of input channels is not divisible by group.");
+  const int C = filter.dim32(3) * G;
+  CAFFE_ENFORCE_EQ(C, dY.dim32(3));
+  CAFFE_ENFORCE_EQ(
+      filter.dim32(1),
+      kernel_h(),
       "filter height must be equal to kernel height");
-  CAFFE_ENFORCE(
-      filter.dim32(2) == this->kernel_w(),
+  CAFFE_ENFORCE_EQ(
+      filter.dim32(2),
+      this->kernel_w(),
       "filter width must be equal to kernel width");
-  const int C = filter.dim32(3);
+  CAFFE_ENFORCE_EQ(dY.dim32(3), C);
+
+  const int K_HxW = kernel_h() * kernel_w();
+  const int kernel_dim = C / G * K_HxW;
+  const int X_HxW = H * W;
+  const int Y_HxW = dY.dim32(1) * dY.dim32(2);
   auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype<T>());
 
-  const int kernel_dim = C * this->kernel_h() * this->kernel_w();
-  const int output_image_size = dY.dim32(1) * dY.dim32(2);
-  // The col buffer is stored in HWC order as well
-  ReinitializeTensor(
-      &col_buffer_,
-      vector<int64_t>{H, W, this->kernel_h(), this->kernel_w(), C},
-      at::dtype<T>().device(Context::GetDeviceType()));
-  if (!no_bias_) {
-    auto* dbias = Output(BIAS_OR_INPUT_GRAD);
-    dbias->Resize(C);
-    // TODO(jerryzh): is it OK to remove the check of whether numel is output_image_size
-    ReinitializeTensor(
-        &bias_multiplier_,
-        {1, output_image_size},
-        at::dtype<T>().device(Context::GetDeviceType()));
-    T* bm_data = bias_multiplier_.template mutable_data<T>();
-    math::Set<T, Context>(
-        output_image_size,
-        static_cast<T>(1),
-        bm_data,
-        &context_);
-  }
-  T* col_buffer_data = col_buffer_.template mutable_data<T>();
-  const T* Xdata = X.template data<T>();
+  const T* X_data = X.template data<T>();
   const T* filter_data = filter.template data<T>();
-  const T* dYdata = dY.template data<T>();
+  const T* dY_data = dY.template data<T>();
   T* dfilter_data = dfilter->template mutable_data<T>();
-  // Pre-setting the gradients to zero
-  math::Set<T, Context>(dfilter->numel(), 0, dfilter_data, &context_);
+  T* dbias_data = nullptr;
+  T* dX_data = nullptr;
   if (!no_bias_) {
-    auto* dbias = Output(BIAS_OR_INPUT_GRAD);
-    T* dbias_data = dbias->template mutable_data<T>();
-    math::Set<T, Context>(dbias->numel(), 0, dbias_data, &context_);
+    auto* dbias = Output(BIAS_OR_INPUT_GRAD, {C}, at::dtype<T>());
+    dbias_data = dbias->template mutable_data<T>();
+  }
+  const bool compute_dX =
+      (OutputSize() == 3) || (no_bias_ && (OutputSize() == 2));
+  if (compute_dX) {
+    auto* dX = Output(
+        no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD, X.sizes(), at::dtype<T>());
+    dX_data = dX->template mutable_data<T>();
   }
-  for (auto image_id = 0; image_id < N; ++image_id) {
+  math::Set<T, Context>(filter.numel(), T(0), dfilter_data, &context_);
+
+  if (N == 0) {
+    math::Set<T, Context>(C, T(0), dbias_data, &context_);
+    return true;
+  }
+
+  ReinitializeTensor(
+      &col_buffer_,
+      std::vector<std::int64_t>{C, kernel_h(), kernel_w(), H, W},
+      at::dtype<T>().device(Context::GetDeviceType()));
+  T* col_buffer_data = col_buffer_.template mutable_data<T>();
+
+  for (int image_id = 0; image_id < N; ++image_id) {
     // gradient w.r.t. filters. Im2Col followed by Gemm
     // Im2Col.
     math::Im2Col<T, Context, StorageOrder::NHWC>(
         C,
         dY.dim32(1),
         dY.dim32(2),
-        this->kernel_h(),
-        this->kernel_w(),
+        kernel_h(),
+        kernel_w(),
         1,
         1,
-        this->pad_t(),
-        this->pad_l(),
-        this->pad_b(),
-        this->pad_r(),
-        this->stride_h(),
-        this->stride_w(),
-        dYdata,
+        pad_t(),
+        pad_l(),
+        pad_b(),
+        pad_r(),
+        stride_h(),
+        stride_w(),
+        dY_data + image_id * C * Y_HxW,
         col_buffer_data,
-        &context_);
+        &context_,
+        G);
     // Gemm
-    math::Gemm<T, Context>(
-        CblasTrans,
-        CblasNoTrans,
-        M,
-        kernel_dim,
-        H * W,
-        1,
-        Xdata,
-        col_buffer_data,
-        1,
-        dfilter_data,
-        &context_);
-    // gradients w.r.t. bias
-    if (!no_bias_) {
-      const T* bm_data = bias_multiplier_.template data<T>();
-      T* input_grad_data = Output(BIAS_OR_INPUT_GRAD)->template mutable_data<T>();
+    if (G == 1) {
       math::Gemm<T, Context>(
           CblasTrans,
           CblasNoTrans,
-          C,
-          1,
-          output_image_size,
-          1,
-          dYdata,
-          bm_data,
-          1,
-          input_grad_data,
-          &context_);
-    }
-    dYdata += dY.numel() / dY.dim32(0);
-    Xdata += X.numel() / X.dim32(0);
-  }
-  if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
-    // Compute gradients w.r.t. the input
-    // Since we have changed dYdata in the above loop, we will need to reset.
-    dYdata = dY.template data<T>();
-
-    auto* dX = Output(
-        no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD, X.sizes(), at::dtype<T>());
-    T* dXdata = dX->template mutable_data<T>();
-    for (auto image_id = 0; image_id < N; ++image_id) {
-      // Im2Col.
-      // TODO(zyan3): Probably duplicate work as in gradient computation
-      // w.r.t filters
-      math::Im2Col<T, Context, StorageOrder::NHWC>(
-          C,
-          dY.dim32(1),
-          dY.dim32(2),
-          this->kernel_h(),
-          this->kernel_w(),
-          1,
-          1,
-          this->pad_t(),
-          this->pad_l(),
-          this->pad_b(),
-          this->pad_r(),
-          this->stride_h(),
-          this->stride_w(),
-          dYdata,
-          col_buffer_data,
-          &context_);
-      // Gemm
-      math::Gemm<T, Context>(
-          CblasNoTrans,
-          CblasTrans,
-          H * W,
           M,
           kernel_dim,
-          1,
+          X_HxW,
+          1.0f,
+          X_data + image_id * M * X_HxW,
           col_buffer_data,
-          filter_data,
-          0,
-          dXdata,
+          1.0f,
+          dfilter_data,
           &context_);
-      dYdata += dY.numel() / dY.dim32(0);
-      dXdata += X.numel() / X.dim32(0);
+    } else {
+      for (int group_id = 0; group_id < G; ++group_id) {
+        math::GemmEx<T, Context>(
+            CblasTrans,
+            CblasNoTrans,
+            M / G,
+            kernel_dim,
+            X_HxW,
+            1.0f,
+            X_data + image_id * M * X_HxW + group_id * M / G,
+            M,
+            col_buffer_data + group_id * kernel_dim,
+            G * kernel_dim,
+            1.0f,
+            dfilter_data + group_id * M / G * kernel_dim,
+            kernel_dim,
+            &context_);
+      }
+    }
+
+    if (dX_data != nullptr) {
+      // Compute gradients w.r.t. the input
+      if (G == 1) {
+        math::Gemm<T, Context>(
+            CblasNoTrans,
+            CblasTrans,
+            X_HxW,
+            M,
+            kernel_dim,
+            1.0f,
+            col_buffer_data,
+            filter_data,
+            0.0f,
+            dX_data + image_id * M * X_HxW,
+            &context_);
+      } else {
+        for (int group_id = 0; group_id < G; ++group_id) {
+          math::GemmEx<T, Context>(
+              CblasNoTrans,
+              CblasTrans,
+              X_HxW,
+              M / G,
+              kernel_dim,
+              1.0f,
+              col_buffer_data + group_id * kernel_dim,
+              G * kernel_dim,
+              filter_data + group_id * M / G * kernel_dim,
+              kernel_dim,
+              0.0f,
+              dX_data + image_id * M * X_HxW + group_id * M / G,
+              M,
+              &context_);
+        }
+      }
     }
   }
+
+  if (dbias_data != nullptr) {
+    const std::array<int, 2> Y_dims = {N * Y_HxW, C};
+    const std::array<int, 2> b_dims = {1, C};
+    math::ReduceSum<T, Context>(
+        2, Y_dims.data(), b_dims.data(), T(1), dY_data, dbias_data, &context_);
+  }
+
   return true;
 }
 
 } // namespace caffe2
+
 #endif // CAFFE2_OPERATORS_CONV_TRANSPOSE_OP_IMPL_H_

caffe2/operators/conv_transpose_unpool_op_base.h
index 7ebfda7..c98b3ba 100644
@@ -17,7 +17,9 @@ template <class Context>
 class ConvTransposeUnpoolBase : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
-  explicit ConvTransposeUnpoolBase(const OperatorDef& operator_def, Workspace* ws)
+  explicit ConvTransposeUnpoolBase(
+      const OperatorDef& operator_def,
+      Workspace* ws)
       : Operator<Context>(operator_def, ws),
         legacy_pad_(
             static_cast<LegacyPadding>(this->template GetSingleArgument<int>(
@@ -27,6 +29,7 @@ class ConvTransposeUnpoolBase : public Operator<Context> {
         stride_(this->template GetRepeatedArgument<int>("strides")),
         pads_(this->template GetRepeatedArgument<int>("pads")),
         adj_(this->template GetRepeatedArgument<int>("adjs")),
+        group_(this->template GetSingleArgument<int>("group", 1)),
         order_(StringToStorageOrder(
             this->template GetSingleArgument<string>("order", "NCHW"))),
         shared_buffer_(
@@ -206,19 +209,7 @@ class ConvTransposeUnpoolBase : public Operator<Context> {
 
   virtual ~ConvTransposeUnpoolBase() {}
 
- private:
-  LegacyPadding legacy_pad_;
-  int pad_;
-
  protected:
-  vector<int> kernel_;
-  vector<int> stride_;
-  vector<int> pads_;
-  vector<int> adj_;
-  StorageOrder order_;
-  bool shared_buffer_;
-  Workspace* ws_;
-
   // Accessors for 2D conv params.
 
   inline int pad_t() const {
@@ -289,14 +280,35 @@ class ConvTransposeUnpoolBase : public Operator<Context> {
         break;
     }
   }
+
+  LegacyPadding legacy_pad_;
+  int pad_;
+
+  std::vector<int> kernel_;
+  std::vector<int> stride_;
+  std::vector<int> pads_;
+  std::vector<int> adj_;
+  int group_;
+  StorageOrder order_;
+  bool shared_buffer_;
+  Workspace* ws_;
 };
 
 #define USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS(Context) \
   USE_OPERATOR_FUNCTIONS(Context);                        \
   using ConvTransposeUnpoolBase<Context>::kernel_;        \
+  using ConvTransposeUnpoolBase<Context>::kernel_h;       \
+  using ConvTransposeUnpoolBase<Context>::kernel_w;       \
   using ConvTransposeUnpoolBase<Context>::stride_;        \
+  using ConvTransposeUnpoolBase<Context>::stride_h;       \
+  using ConvTransposeUnpoolBase<Context>::stride_w;       \
   using ConvTransposeUnpoolBase<Context>::pads_;          \
+  using ConvTransposeUnpoolBase<Context>::pad_t;          \
+  using ConvTransposeUnpoolBase<Context>::pad_l;          \
+  using ConvTransposeUnpoolBase<Context>::pad_b;          \
+  using ConvTransposeUnpoolBase<Context>::pad_r;          \
   using ConvTransposeUnpoolBase<Context>::adj_;           \
+  using ConvTransposeUnpoolBase<Context>::group_;         \
   using ConvTransposeUnpoolBase<Context>::order_;         \
   using ConvTransposeUnpoolBase<Context>::shared_buffer_; \
   using ConvTransposeUnpoolBase<Context>::ws_

caffe2/python/operator_test/conv_transpose_test.py
index 2a32f9a..272ac3a 100644
@@ -6,6 +6,7 @@ import numpy as np
 from hypothesis import assume, given, settings
 import hypothesis.strategies as st
 
+from caffe2.proto import caffe2_pb2
 from caffe2.python import core, utils
 import caffe2.python.hypothesis_test_util as hu
 import caffe2.python.hip_test_util as hiputl
@@ -360,6 +361,68 @@ class TestConvolutionTranspose(hu.HypothesisTestCase):
         for i in outputs_to_check:
             self.assertGradientChecks(gc, op, inputs, i, [0])
 
+    @given(stride=st.integers(1, 3),
+           pad=st.integers(0, 3),
+           kernel=st.integers(1, 3),
+           adj=st.integers(0, 2),
+           size=st.integers(7, 10),
+           input_channels=st.integers(1, 8),
+           output_channels=st.integers(1, 8),
+           batch_size=st.integers(1, 4),
+           group=st.integers(1, 4),
+           order=st.sampled_from(["NCHW", "NHWC"]),
+           engine=st.sampled_from(["", "CUDNN", "BLOCK"]),
+           shared_buffer=st.booleans(),
+           use_bias=st.booleans(),
+           **hu.gcs)
+    def test_convolution_transpose_with_group(
+            self, stride, pad, kernel, adj, size, input_channels,
+            output_channels, batch_size, group, order, engine, shared_buffer,
+            use_bias, gc, dc):
+        assume(adj < stride)
+        # TODO: Group conv_transpose in NHWC not implemented for GPU yet.
+        assume(group == 1 or order == "NCHW" or
+               gc.device_type == caffe2_pb2.CPU)
+        if group != 1 and order == "NHWC":
+            dc = [d for d in dc if d.device_type == caffe2_pb2.CPU]
+
+        if hiputl.run_in_hip(gc, dc) and order == "NHWC":
+            engine = ""
+
+        op = core.CreateOperator(
+            "ConvTranspose",
+            ["X", "w", "b"] if use_bias else ["X", "w"],
+            ["Y"],
+            stride=stride,
+            kernel=kernel,
+            pad=pad,
+            adj=adj,
+            group=group,
+            order=order,
+            engine=engine,
+            shared_buffer=int(shared_buffer),
+            device_option=gc,
+        )
+
+        input_channels *= group
+        output_channels *= group
+
+        X = np.random.rand(
+            batch_size, size, size, input_channels).astype(np.float32) - 0.5
+        w = np.random.rand(
+            input_channels, kernel, kernel, int(output_channels / group)) \
+            .astype(np.float32) - 0.5
+        b = np.random.rand(output_channels).astype(np.float32) - 0.5
+        if order == "NCHW":
+            X = utils.NHWC2NCHW(X)
+            w = utils.NHWC2NCHW(w)
+
+        inputs = [X, w, b] if use_bias else [X, w]
+        self.assertDeviceChecks(dc, op, inputs, [0])
+        for i in range(len(inputs)):
+            self.assertGradientChecks(gc, op, inputs, i, [0])
+
+
 if __name__ == "__main__":
     import unittest
     unittest.main()
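
As a closing illustration (not part of the commit): a grouped ConvTranspose is equivalent to splitting the input channels into `group` slices, running an ungrouped ConvTranspose per slice with the matching filter rows, and concatenating the results along the output-channel axis. A hedged NCHW reference sketch, with illustrative blob and function names:

import numpy as np
from caffe2.python import core, workspace

def grouped_conv_transpose_by_slicing(X, w, group, **conv_args):
    # X: (N, M, H, W), w: (M, C/group, kH, kW); returns (N, C, H_out, W_out).
    M = X.shape[1]
    step = M // group
    outs = []
    for g in range(group):
        workspace.FeedBlob("Xg", np.ascontiguousarray(X[:, g * step:(g + 1) * step]))
        workspace.FeedBlob("wg", np.ascontiguousarray(w[g * step:(g + 1) * step]))
        op = core.CreateOperator(
            "ConvTranspose", ["Xg", "wg"], ["Yg"],
            group=1, order="NCHW", **conv_args)
        workspace.RunOperatorOnce(op)
        outs.append(workspace.FetchBlob("Yg"))
    return np.concatenate(outs, axis=1)

For example, grouped_conv_transpose_by_slicing(X, w, 4, kernel=3, stride=2, pad=1, adj=0) should agree, up to floating-point rounding, with a single ConvTranspose run with group=4 on the full X and w (bias omitted here for brevity).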