Adding a depthwise convolution kernel op (with label 'cudnn_grouped_convolution'...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 30 Apr 2018 13:59:23 +0000 (06:59 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 30 Apr 2018 14:01:55 +0000 (07:01 -0700)
PiperOrigin-RevId: 194780352

tensorflow/core/kernels/BUILD
tensorflow/core/kernels/conv_grad_filter_ops.cc
tensorflow/core/kernels/conv_grad_input_ops.cc
tensorflow/core/kernels/conv_grad_ops.cc
tensorflow/core/kernels/conv_ops.cc
tensorflow/core/kernels/depthwise_conv_grad_op.cc
tensorflow/core/kernels/depthwise_conv_op.cc
tensorflow/python/kernel_tests/depthwise_conv_op_test.py
tensorflow/stream_executor/cuda/cuda_dnn.cc
tensorflow/stream_executor/dnn.cc
tensorflow/stream_executor/dnn.h

index 6355f13..3fb03cd 100644 (file)
@@ -3299,7 +3299,10 @@ tf_kernel_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:nn_ops_op_lib",
-    ] + if_cuda(["@cub_archive//:cub"]),
+    ] + if_cuda([
+        "@cub_archive//:cub",
+        "@local_config_cuda//cuda:cudnn",
+    ]),
 )
 
 tf_kernel_library(
@@ -3310,12 +3313,15 @@ tf_kernel_library(
     prefix = "depthwise_conv_grad_op",
     deps = [
         ":bounds_check",
+        ":conv_ops",
         ":ops_util",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:nn_ops_op_lib",
-    ],
+    ] + if_cuda([
+        "@local_config_cuda//cuda:cudnn",
+    ]),
 )
 
 cc_library(
index ef1e73e..aca7517 100644 (file)
@@ -96,7 +96,8 @@ template <typename T>
 struct LaunchConv2DBackpropFilterOp<CPUDevice, T> {
   void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
                   const Tensor& out_backprop, const Tensor& input,
-                  int row_stride, int col_stride, const Padding& padding,
+                  int row_dilation, int col_dilation, int row_stride,
+                  int col_stride, const Padding& padding,
                   Tensor* filter_backprop, TensorFormat data_format) {
     const CPUDevice& d = ctx->eigen_device<CPUDevice>();
     functor::SpatialConvolutionBackwardFilter<CPUDevice, T>()(
@@ -275,7 +276,8 @@ class Conv2DFastBackpropFilterOp : public OpKernel {
 #endif
 
     LaunchConv2DBackpropFilterOp<Device, T>()(
-        context, false, false, out_backprop, input, dims.spatial_dims[0].stride,
+        context, false, false, out_backprop, input,
+        /*row_dilation=*/1, /*col_dilation=*/1, dims.spatial_dims[0].stride,
         dims.spatial_dims[1].stride, padding_, filter_backprop, data_format_);
   }
 
@@ -523,6 +525,11 @@ TF_CALL_float(REGISTER_CPU_KERNELS);
 TF_CALL_double(REGISTER_CPU_KERNELS);
 #undef REGISTER_CPU_KERNELS
 
+// To be used inside depthwise_conv_grad_op.cc.
+template struct LaunchConv2DBackpropFilterOp<CPUDevice, Eigen::half>;
+template struct LaunchConv2DBackpropFilterOp<CPUDevice, float>;
+template struct LaunchConv2DBackpropFilterOp<CPUDevice, double>;
+
 // GPU definitions.
 #if GOOGLE_CUDA
 // The slow version (but compiles for GPU)
@@ -690,10 +697,15 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
     return;
   }
 
+  // If the filter in-depth (filter_shape.dim_size(2)) is 1 and smaller than the
+  // input depth, it's a depthwise convolution. More generally, if the filter
+  // in-depth divides but is smaller than the input depth, it is a grouped
+  // convolution.
+  bool is_grouped_convolution = filter_shape.dim_size(2) != dims.in_depth;
   bool cudnn_disable_conv_1x1_optimization_ = CudnnDisableConv1x1Optimization();
   if (!cudnn_disable_conv_1x1_optimization_ &&
       dims.spatial_dims[0].filter_size == 1 &&
-      dims.spatial_dims[1].filter_size == 1 &&
+      dims.spatial_dims[1].filter_size == 1 && !is_grouped_convolution &&
       dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
       data_format == FORMAT_NHWC) {
     const uint64 m = dims.in_depth;
@@ -734,9 +746,10 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
                  dims.spatial_dims[0].input_size &&
              dims.spatial_dims[1].filter_size ==
                  dims.spatial_dims[1].input_size &&
-             padding == VALID && data_format == FORMAT_NHWC) {
-    // The input data and filter have the same height/width, so call cublas
-    // directly.
+             !is_grouped_convolution && padding == VALID &&
+             data_format == FORMAT_NHWC) {
+    // The input data and filter have the same height/width, and we are not
+    // using grouped convolution, so call cublas directly.
     const uint64 m = dims.spatial_dims[0].input_size *
                      dims.spatial_dims[1].input_size * dims.in_depth;
     const uint64 k = dims.batch_size;
@@ -802,15 +815,16 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
   se::dnn::FilterDescriptor filter_desc;
   filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
       .set_input_filter_width(dims.spatial_dims[1].filter_size)
-      .set_input_feature_map_count(dims.in_depth)
-      .set_output_feature_map_count(dims.out_depth);
+      .set_input_feature_map_count(filter_shape.dim_size(2))
+      .set_output_feature_map_count(filter_shape.dim_size(3));
   se::dnn::ConvolutionDescriptor conv_desc;
   conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
       .set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
       .set_vertical_filter_stride(dims.spatial_dims[0].stride)
       .set_horizontal_filter_stride(dims.spatial_dims[1].stride)
       .set_zero_padding_height(padding_rows / 2)
-      .set_zero_padding_width(padding_cols / 2);
+      .set_zero_padding_width(padding_cols / 2)
+      .set_group_count(dims.in_depth / filter_shape.dim_size(2));
 
   // NOTE(zhengxq):
   // cuDNN only supports the following layouts :
@@ -891,21 +905,22 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
   int device_id = stream->parent()->device_ordinal();
   DataType dtype = input.dtype();
   ConvParameters conv_parameters = {
-      dims.batch_size,                       // batch
-      dims.in_depth,                         // in_depths
-      {{input_desc.height(),                 // in_rows
-        input_desc.width()}},                // in_cols
-      dims.out_depth,                        // out_depths
-      {{dims.spatial_dims[0].filter_size,    // filter_rows
-        dims.spatial_dims[1].filter_size}},  // filter_cols
-      {{dims.spatial_dims[0].dilation,       // dilation_rows
-        dims.spatial_dims[1].dilation}},     // dilation_cols
-      {{dims.spatial_dims[0].stride,         // stride_rows
-        dims.spatial_dims[1].stride}},       // stride_cols
-      {{padding_rows,                        // padding_rows
-        padding_cols}},                      // padding_cols
-      dtype,                                 // tensor datatype
-      device_id,                             // device_id
+      dims.batch_size,                     // batch
+      dims.in_depth,                       // in_depths
+      {{input_desc.height(),               // in_rows
+        input_desc.width()}},              // in_cols
+      dims.out_depth,                      // out_depths
+      {{dims.spatial_dims[0].filter_size,  // filter_rows
+        dims.spatial_dims[1].filter_size,  // filter_cols
+        filter_shape.dim_size(2)}},        // filter_depth
+      {{dims.spatial_dims[0].dilation,     // dilation_rows
+        dims.spatial_dims[1].dilation}},   // dilation_cols
+      {{dims.spatial_dims[0].stride,       // stride_rows
+        dims.spatial_dims[1].stride}},     // stride_cols
+      {{padding_rows,                      // padding_rows
+        padding_cols}},                    // padding_cols
+      dtype,                               // tensor datatype
+      device_id,                           // device_id
   };
   AlgorithmConfig algorithm_config;
   if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find(
@@ -1019,9 +1034,9 @@ namespace functor {
       typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
   extern template struct PadInput<GPUDevice, T, int, 4>;
 
-DECLARE_GPU_SPEC(double);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(Eigen::half);
+DECLARE_GPU_SPEC(double);
 #undef DECLARE_GPU_SPEC
 }  // namespace functor
 
@@ -1040,6 +1055,12 @@ REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")
                             .TypeConstraint<Eigen::half>("T")
                             .HostMemory("filter_sizes"),
                         Conv2DSlowBackpropFilterOp<GPUDevice, Eigen::half>);
+
+// To be used inside depthwise_conv_grad_op.cc.
+template struct LaunchConv2DBackpropFilterOp<GPUDevice, float>;
+template struct LaunchConv2DBackpropFilterOp<GPUDevice, Eigen::half>;
+template struct LaunchConv2DBackpropFilterOp<GPUDevice, double>;
+
 #endif  // GOOGLE_CUDA
 
 }  // namespace tensorflow
index 35f2676..63a775a 100644 (file)
@@ -101,8 +101,9 @@ template <typename T>
 struct LaunchConv2DBackpropInputOp<CPUDevice, T> {
   void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
                   const Tensor& out_backprop, const Tensor& filter,
-                  int row_stride, int col_stride, const Padding& padding,
-                  Tensor* in_backprop, TensorFormat data_format) {
+                  int row_dilation, int col_dilation, int row_stride,
+                  int col_stride, const Padding& padding, Tensor* in_backprop,
+                  TensorFormat data_format) {
     const CPUDevice& d = ctx->eigen_device<CPUDevice>();
     functor::SpatialConvolutionBackwardInput<CPUDevice, T>()(
         d, in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
@@ -280,8 +281,8 @@ class Conv2DFastBackpropInputOp : public OpKernel {
 
     LaunchConv2DBackpropInputOp<Device, T>()(
         context, false, false, out_backprop, filter,
-        dims.spatial_dims[0].stride, dims.spatial_dims[1].stride, padding_,
-        in_backprop, data_format_);
+        /*row_dilation=*/1, /*col_dilation=*/1, dims.spatial_dims[0].stride,
+        dims.spatial_dims[1].stride, padding_, in_backprop, data_format_);
   }
 
  private:
@@ -595,6 +596,11 @@ TF_CALL_float(REGISTER_CPU_KERNELS);
 TF_CALL_double(REGISTER_CPU_KERNELS);
 #undef REGISTER_CPU_KERNELS
 
+// To be used inside depthwise_conv_grad_op.cc.
+template struct LaunchConv2DBackpropInputOp<CPUDevice, Eigen::half>;
+template struct LaunchConv2DBackpropInputOp<CPUDevice, float>;
+template struct LaunchConv2DBackpropInputOp<CPUDevice, double>;
+
 // GPU definitions.
 #if GOOGLE_CUDA
 // The slow version (but compiles for GPU)
@@ -761,8 +767,13 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
     return;
   }
 
+  // If the filter in-depth (filter_shape.dim_size(2)) is 1 and smaller than the
+  // input depth, it's a depthwise convolution. More generally, if the filter
+  // in-depth divides but is smaller than the input depth, it is a grouped
+  // convolution.
+  bool is_grouped_convolution = filter_shape.dim_size(2) != dims.in_depth;
   if (dims.spatial_dims[0].filter_size == 1 &&
-      dims.spatial_dims[1].filter_size == 1 &&
+      dims.spatial_dims[1].filter_size == 1 && !is_grouped_convolution &&
       dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
       data_format == FORMAT_NHWC) {
     // 1x1 filter, so call cublas directly.
@@ -795,9 +806,10 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
                  dims.spatial_dims[0].input_size &&
              dims.spatial_dims[1].filter_size ==
                  dims.spatial_dims[1].input_size &&
-             padding == VALID && data_format == FORMAT_NHWC) {
-    // The input data and filter have the same height/width, so call cublas
-    // directly.
+             !is_grouped_convolution && padding == VALID &&
+             data_format == FORMAT_NHWC) {
+    // The input data and filter have the same height/width, and we are not
+    // using grouped convolution, so call cublas directly.
     const uint64 m = dims.batch_size;
     const uint64 k = dims.out_depth;
     const uint64 n = dims.spatial_dims[0].input_size *
@@ -856,15 +868,16 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
   se::dnn::FilterDescriptor filter_desc;
   filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
       .set_input_filter_width(dims.spatial_dims[1].filter_size)
-      .set_input_feature_map_count(dims.in_depth)
-      .set_output_feature_map_count(dims.out_depth);
+      .set_input_feature_map_count(filter_shape.dim_size(2))
+      .set_output_feature_map_count(filter_shape.dim_size(3));
   se::dnn::ConvolutionDescriptor conv_desc;
   conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
       .set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
       .set_vertical_filter_stride(dims.spatial_dims[0].stride)
       .set_horizontal_filter_stride(dims.spatial_dims[1].stride)
       .set_zero_padding_height(padding_rows / 2)
-      .set_zero_padding_width(padding_cols / 2);
+      .set_zero_padding_width(padding_cols / 2)
+      .set_group_count(dims.in_depth / filter_shape.dim_size(2));
 
   // NOTE(keveman):
   // cuDNN only supports the following layouts :
@@ -940,21 +953,22 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
   int device_id = stream->parent()->device_ordinal();
   DataType dtype = out_backprop.dtype();
   ConvParameters conv_parameters = {
-      dims.batch_size,                       // batch
-      dims.in_depth,                         // in_depths
-      {{input_desc.height(),                 // in_rows
-        input_desc.width()}},                // in_cols
-      dims.out_depth,                        // out_depths
-      {{dims.spatial_dims[0].filter_size,    // filter_rows
-        dims.spatial_dims[1].filter_size}},  // filter_cols
-      {{dims.spatial_dims[0].dilation,       // dilation_rows
-        dims.spatial_dims[1].dilation}},     // dilation_cols
-      {{dims.spatial_dims[0].stride,         // stride_rows
-        dims.spatial_dims[1].stride}},       // stride_cols
-      {{padding_rows,                        // padding_rows
-        padding_cols}},                      // padding_cols
-      dtype,                                 // tensor data type
-      device_id,                             // device_id
+      dims.batch_size,                     // batch
+      dims.in_depth,                       // in_depths
+      {{input_desc.height(),               // in_rows
+        input_desc.width()}},              // in_cols
+      dims.out_depth,                      // out_depths
+      {{dims.spatial_dims[0].filter_size,  // filter_rows
+        dims.spatial_dims[1].filter_size,  // filter_cols
+        filter_shape.dim_size(2)}},        // filter_depths
+      {{dims.spatial_dims[0].dilation,     // dilation_rows
+        dims.spatial_dims[1].dilation}},   // dilation_cols
+      {{dims.spatial_dims[0].stride,       // stride_rows
+        dims.spatial_dims[1].stride}},     // stride_cols
+      {{padding_rows,                      // padding_rows
+        padding_cols}},                    // padding_cols
+      dtype,                               // tensor data type
+      device_id,                           // device_id
   };
   AlgorithmConfig algorithm_config;
   if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(
@@ -1092,9 +1106,9 @@ namespace functor {
       typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
   extern template struct PadInput<GPUDevice, T, int, 4>;
 
-DECLARE_GPU_SPEC(double);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(Eigen::half);
+DECLARE_GPU_SPEC(double);
 #undef DECLARE_GPU_SPEC
 }  // namespace functor
 
@@ -1113,6 +1127,12 @@ REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
                             .TypeConstraint<Eigen::half>("T")
                             .HostMemory("input_sizes"),
                         Conv2DSlowBackpropInputOp<GPUDevice, Eigen::half>);
+
+// To be used inside depthwise_conv_grad_op.cc.
+template struct LaunchConv2DBackpropInputOp<GPUDevice, float>;
+template struct LaunchConv2DBackpropInputOp<GPUDevice, Eigen::half>;
+template struct LaunchConv2DBackpropInputOp<GPUDevice, double>;
+
 #endif  // GOOGLE_CUDA
 
 }  // namespace tensorflow
index 170ce31..5bf709a 100644 (file)
@@ -127,16 +127,17 @@ Status ConvBackpropComputeDimensionsV2(
   dims->in_depth = input_shape.dim_size(feature_dim);
   // The input and output feature dimensions are the second last and last
   // dimensions of the filter Tensor.
-  if (dims->in_depth != filter_shape.dim_size(num_dims - 2)) {
+  VLOG(2) << "input vs filter_in depth " << dims->in_depth << " "
+          << filter_shape.dim_size(num_dims - 2);
+  if (dims->in_depth % filter_shape.dim_size(num_dims - 2)) {
     return errors::InvalidArgument(
-        label, ": input and filter must have the same depth");
+        label, ": input depth must be evenly divisible by filter depth");
   }
   dims->out_depth = filter_shape.dim_size(num_dims - 1);
   if (dims->out_depth != out_backprop_shape.dim_size(feature_dim)) {
     return errors::InvalidArgument(
         label, ": filter and out_backprop must have the same out_depth");
   }
-
   dims->spatial_dims.resize(num_spatial_dims);
   for (int i = 0; i < num_spatial_dims; ++i) {
     int image_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
index c6d36b4..3b9886e 100644 (file)
@@ -18,10 +18,16 @@ limitations under the License.
 #define USE_EIGEN_TENSOR
 #define EIGEN_USE_THREADS
 
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif  // GOOGLE_CUDA
+
 #include "tensorflow/core/kernels/conv_ops.h"
+
 #include <string.h>
 #include <map>
 #include <vector>
+
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -32,9 +38,6 @@ limitations under the License.
 #include "tensorflow/core/kernels/conv_2d.h"
 #include "tensorflow/core/kernels/deep_conv2d.h"
 #include "tensorflow/core/kernels/ops_util.h"
-#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
-#include "tensorflow/core/kernels/xsmm_conv2d.h"
-#endif
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/lib/strings/numbers.h"
@@ -45,6 +48,10 @@ limitations under the License.
 #include "tensorflow/core/util/tensor_format.h"
 #include "tensorflow/core/util/use_cudnn.h"
 
+#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
+#include "tensorflow/core/kernels/xsmm_conv2d.h"
+#endif
+
 #if GOOGLE_CUDA
 #include "tensorflow/core/kernels/conv_ops_gpu.h"
 #include "tensorflow/core/platform/stream_executor.h"
@@ -123,6 +130,10 @@ struct LaunchConv2DOp<CPUDevice, T> {
                                 "NHWC tensor format for now."));
       return;
     }
+    const int64 in_depth = GetTensorDim(input, data_format, 'C');
+    OP_REQUIRES(ctx, in_depth == filter.dim_size(2),
+                errors::Unimplemented("Generic conv implementation does not "
+                                      "support grouped convolutions for now."));
     LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
                                   row_dilation, col_dilation, padding, output,
                                   data_format);
@@ -324,12 +335,13 @@ class Conv2DOp : public BinaryOp<T> {
     }
 
     // The last dimension for input is in_depth. It must be the same as the
-    // filter's in_depth.
+    // filter's in_depth or be evenly divisible by filter's in_depth.
     const int64 in_depth = GetTensorDim(input, data_format_, 'C');
-    OP_REQUIRES(context, in_depth == filter.dim_size(2),
+    const int64 patch_depth = filter.dim_size(2);
+    OP_REQUIRES(context, in_depth % patch_depth == 0,
                 errors::InvalidArgument(
-                    "input and filter must have the same depth: ", in_depth,
-                    " vs ", filter.dim_size(2)));
+                    "input depth must be evenly divisible by filter depth: ",
+                    in_depth, " vs ", patch_depth));
 
     // The last dimension for filter is out_depth.
     const int out_depth = static_cast<int>(filter.dim_size(3));
@@ -386,6 +398,7 @@ class Conv2DOp : public BinaryOp<T> {
     OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
 
     VLOG(2) << "Conv2D: in_depth = " << in_depth
+            << ", patch_depth = " << patch_depth
             << ", input_cols = " << input_cols
             << ", filter_cols = " << filter_cols
             << ", input_rows = " << input_rows
@@ -450,7 +463,9 @@ TF_CALL_double(REGISTER_CPU);
 #endif  // USE_GEMM_FOR_CONV
 
 // To be used inside depthwise_conv_op.cc.
+template struct LaunchConv2DOp<CPUDevice, Eigen::half>;
 template struct LaunchConv2DOp<CPUDevice, float>;
+template struct LaunchConv2DOp<CPUDevice, double>;
 
 #if GOOGLE_CUDA
 int64 GetCudnnWorkspaceLimit(const string& envvar_in_mb,
@@ -498,13 +513,24 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
   }
 
   Tensor input = input_param;
-
-  if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_dilation == 1 &&
-      col_dilation == 1 && row_stride == 1 && col_stride == 1 &&
-      data_format == FORMAT_NHWC) {
+  const int64 in_batch = GetTensorDim(input, data_format, 'N');
+  int64 in_rows = GetTensorDim(input, data_format, 'H');
+  int64 in_cols = GetTensorDim(input, data_format, 'W');
+  const int64 in_depths = GetTensorDim(input, data_format, 'C');
+  const int64 patch_rows = filter.dim_size(0);
+  const int64 patch_cols = filter.dim_size(1);
+  const int64 patch_depths = filter.dim_size(2);
+
+  // If the filter in-depth (patch_depths) is 1 and smaller than the input
+  // depth, it's a depthwise convolution. More generally, if the filter in-depth
+  // divides but is smaller than the input depth, it is a grouped convolution.
+  bool is_grouped_convolution = patch_depths != in_depths;
+  if (patch_rows == 1 && patch_cols == 1 && !is_grouped_convolution &&
+      row_dilation == 1 && col_dilation == 1 && row_stride == 1 &&
+      col_stride == 1 && data_format == FORMAT_NHWC) {
     // 1x1 filter, so call cublas directly.
-    const uint64 m = input.dim_size(0) * input.dim_size(1) * input.dim_size(2);
-    const uint64 k = filter.dim_size(2);
+    const uint64 m = in_batch * in_rows * in_cols;
+    const uint64 k = patch_depths;
     const uint64 n = filter.dim_size(3);
 
     auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
@@ -525,15 +551,14 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
                                       ", n=", n, ", k=", k));
     }
     return;
-  } else if (filter.dim_size(0) == input.dim_size(1) &&
-             filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 &&
+  } else if (patch_rows == in_rows && patch_cols == in_cols &&
+             !is_grouped_convolution && row_dilation == 1 &&
              col_dilation == 1 && padding == VALID &&
              data_format == FORMAT_NHWC) {
     // The input data and filter have the same height/width, so call cublas
     // directly.
-    const uint64 m = input.dim_size(0);
-    const uint64 k =
-        filter.dim_size(0) * filter.dim_size(1) * filter.dim_size(2);
+    const uint64 m = in_batch;
+    const uint64 k = patch_rows * patch_cols * patch_depths;
     const uint64 n = filter.dim_size(3);
 
     auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
@@ -558,16 +583,10 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
 
   int padding_rows = 0;
   int padding_cols = 0;
-  const int64 in_batch = GetTensorDim(input, data_format, 'N');
-  int64 in_rows = GetTensorDim(input, data_format, 'H');
-  int64 in_cols = GetTensorDim(input, data_format, 'W');
-  const int64 in_depths = GetTensorDim(input, data_format, 'C');
   const int64 out_batch = GetTensorDim(*output, data_format, 'N');
   const int64 out_rows = GetTensorDim(*output, data_format, 'H');
   const int64 out_cols = GetTensorDim(*output, data_format, 'W');
   const int64 out_depths = GetTensorDim(*output, data_format, 'C');
-  const int64 patch_rows = filter.dim_size(0);
-  const int64 patch_cols = filter.dim_size(1);
   if (padding == SAME) {
     // Total padding on rows and cols is
     // Pr = (R' - 1) * S + (Kr - 1) * Dr + 1 - R
@@ -642,9 +661,9 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
       .set_feature_map_count(out_depths)
       .set_layout(se::dnn::DataLayout::kBatchDepthYX);
   se::dnn::FilterDescriptor filter_desc;
-  filter_desc.set_input_filter_height(filter.dim_size(0))
-      .set_input_filter_width(filter.dim_size(1))
-      .set_input_feature_map_count(filter.dim_size(2))
+  filter_desc.set_input_filter_height(patch_rows)
+      .set_input_filter_width(patch_cols)
+      .set_input_feature_map_count(patch_depths)
       .set_output_feature_map_count(filter.dim_size(3));
   se::dnn::ConvolutionDescriptor conv_desc;
   conv_desc.set_vertical_dilation_rate(row_dilation)
@@ -652,7 +671,8 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
       .set_vertical_filter_stride(row_stride)
       .set_horizontal_filter_stride(col_stride)
       .set_zero_padding_height(padding_rows / 2)
-      .set_zero_padding_width(padding_cols / 2);
+      .set_zero_padding_width(padding_cols / 2)
+      .set_group_count(in_depths / patch_depths);
 
   Tensor transformed_filter;
   OP_REQUIRES_OK(ctx, ctx->allocate_temp(
@@ -695,7 +715,8 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
         in_cols}},       // in_cols
       out_depths,        // out_depths
       {{patch_rows,      // filter_rows
-        patch_cols}},    // filter_cols
+        patch_cols,      // filter_cols
+        patch_depths}},  // filter_depths
       {{row_dilation,    // dilation_rows
         col_dilation}},  // dilation_cols
       {{row_stride,      // stride_rows
@@ -812,9 +833,9 @@ namespace functor {
       typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format);     \
   extern template struct PadInput<GPUDevice, T, int, 4>
 
-DECLARE_GPU_SPEC(double);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(Eigen::half);
+DECLARE_GPU_SPEC(double);
 #undef DECLARE_GPU_SPEC
 }  // namespace functor
 
@@ -830,7 +851,9 @@ REGISTER_KERNEL_BUILDER(
     Conv2DOp<GPUDevice, double>);
 
 // To be used inside depthwise_conv_op.cc.
-template class LaunchConv2DOp<GPUDevice, float>;
+template struct LaunchConv2DOp<GPUDevice, float>;
+template struct LaunchConv2DOp<GPUDevice, Eigen::half>;
+template struct LaunchConv2DOp<GPUDevice, double>;
 
 #endif  // GOOGLE_CUDA
 
index 91a9587..7afa21a 100644 (file)
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/conv_grad_ops.h"
 #include "tensorflow/core/kernels/depthwise_conv_op.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -33,9 +34,11 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/padding.h"
 #include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/util/use_cudnn.h"
 #include "tensorflow/core/util/work_sharder.h"
 
 #if GOOGLE_CUDA
+#include "cuda/include/cudnn.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #endif  // GOOGLE_CUDA
 
@@ -509,8 +512,19 @@ static void DepthwiseConvBackpropInputReference(const DepthwiseArgs& args,
   }
 }
 
+// Extern template instantiated in conv_grad_input_ops.cc.
+extern template struct LaunchConv2DBackpropInputOp<CPUDevice, Eigen::half>;
+extern template struct LaunchConv2DBackpropInputOp<CPUDevice, float>;
+extern template struct LaunchConv2DBackpropInputOp<CPUDevice, double>;
+
 #if GOOGLE_CUDA
 
+// Extern template instantiated in conv_grad_input_ops.cc.
+extern template struct LaunchConv2DBackpropInputOp<GPUDevice, Eigen::half>;
+extern template struct LaunchConv2DBackpropInputOp<GPUDevice, float>;
+extern template struct LaunchConv2DBackpropInputOp<GPUDevice, double>;
+
+// Extern template instantiated in depthwise_conv_op_gpu.cu.cc.
 extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice,
                                                           Eigen::half>;
 extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, float>;
@@ -548,6 +562,12 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
         errors::InvalidArgument("Current implementation does not yet support "
                                 "strides in the batch and depth dimensions."));
     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+
+    // For in_depth == 1 and grouped convolutions.
+    use_cudnn_ = CanUseCudnn();
+    cudnn_use_autotune_ = CudnnUseAutotune();
+    use_cudnn_grouped_conv_ = false;
+    dtype_ = DataTypeToEnum<T>::value;
   }
 
   void Compute(OpKernelContext* context) override {
@@ -560,6 +580,7 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
             input_sizes.dims()));
     TensorShape input_shape;
     const int32* in_sizes_data = input_sizes.template flat<int32>().data();
+
     for (int i = 0; i < input_sizes.NumElements(); ++i) {
       OP_REQUIRES(context, in_sizes_data[i] >= 0,
                   errors::InvalidArgument("Dimension ", i,
@@ -568,27 +589,77 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
     }
     const TensorShape& filter_shape = filter.shape();
     EXTRACT_AND_VERIFY_DIMENSIONS("DepthwiseConv2DBackpropInput");
+
     Tensor* in_backprop = nullptr;
     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                 {0}, 0, input_shape, &in_backprop));
-    auto out_backprop_ptr = out_backprop.template flat<T>().data();
-    auto filter_ptr = filter.template flat<T>().data();
-    auto in_backprop_ptr = in_backprop->template flat<T>().data();
+
     // If there is nothing to compute, return.
     if (input_shape.num_elements() == 0) {
       return;
     }
+
+    // If in_depth==1, this operation is just a standard convolution.
+    // Depthwise convolution is a special case of cuDNN's grouped convolution.
+    bool use_cudnn = use_cudnn_ && (in_depth == 1 || use_cudnn_grouped_conv_);
+
+    VLOG(2) << "DepthwiseConv2dNativeBackpropInput: "
+            << " Input: [" << batch << ", " << input_rows << ", " << input_cols
+            << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
+            << filter_cols << ", " << in_depth << ", " << depth_multiplier
+            << "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
+            << ", " << out_depth << "], stride = " << stride_
+            << ", pad_rows = " << pad_rows << ", pad_cols = " << pad_cols
+            << ", Use cuDNN: " << use_cudnn;
+
+    if (use_cudnn) {
+      // Reshape from TF depthwise filter to cuDNN grouped convolution filter:
+      //
+      //                  | TensorFlow       | cuDNN
+      // --------------------------------------------------------------------
+      // filter_out_depth | depth_multiplier | depth_multiplier * group_count
+      // filter_in_depth  | in_depth         | in_depth / group_count
+      //
+      // For depthwise convolution, we have group_count == in_depth.
+      int32 filter_in_depth = 1;
+      TensorShape shape =
+          TensorShape{filter_rows, filter_cols, filter_in_depth, out_depth};
+      Tensor reshaped_filter(/*type=*/dtype_);
+      OP_REQUIRES(
+          context, reshaped_filter.CopyFrom(filter, shape),
+          errors::Internal(
+              "Failed to reshape filter tensor for grouped convolution."));
+      // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
+      // conv is supported.
+      launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop,
+                reshaped_filter, /*row_dilation=*/1, /*col_dilation=*/1,
+                stride_, stride_, padding_, in_backprop, data_format_);
+      return;
+    }
+
+    auto out_backprop_ptr = out_backprop.template flat<T>().data();
+    auto filter_ptr = filter.template flat<T>().data();
+    auto in_backprop_ptr = in_backprop->template flat<T>().data();
     LaunchDepthwiseConvBackpropInputOp<Device, T>()(
         context, args, out_backprop_ptr, filter_ptr, in_backprop_ptr,
         data_format_);
   }
 
+ protected:
+  bool use_cudnn_grouped_conv_;
+
  private:
   std::vector<int32> strides_;
   Padding padding_;
   TensorFormat data_format_;
   int64 stride_;
 
+  // For in_depth == 1 and grouped convolutions.
+  LaunchConv2DBackpropInputOp<Device, T> launcher_;
+  bool use_cudnn_;
+  bool cudnn_use_autotune_;
+  DataType dtype_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeBackpropInputOp);
 };
 
@@ -597,23 +668,52 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
                               .Device(DEVICE_CPU)                    \
                               .TypeConstraint<T>("T"),               \
                           DepthwiseConv2dNativeBackpropInputOp<CPUDevice, T>);
+
+TF_CALL_half(REGISTER_CPU_KERNEL);
 TF_CALL_float(REGISTER_CPU_KERNEL);
+#if !defined(PLATFORM_WINDOWS) || !defined(_DEBUG)
 TF_CALL_double(REGISTER_CPU_KERNEL);
+#endif
 #undef REGISTER_CPU_KERNEL
 
 #if GOOGLE_CUDA
-REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropInput")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<float>("T")
-                            .HostMemory("input_sizes"),
-                        DepthwiseConv2dNativeBackpropInputOp<GPUDevice, float>);
-
-REGISTER_KERNEL_BUILDER(
-    Name("DepthwiseConv2dNativeBackpropInput")
-        .Device(DEVICE_GPU)
-        .TypeConstraint<double>("T")
-        .HostMemory("input_sizes"),
-    DepthwiseConv2dNativeBackpropInputOp<GPUDevice, double>);
+
+#define REGISTER_GPU_KERNEL(T)                                       \
+  REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropInput") \
+                              .Device(DEVICE_GPU)                    \
+                              .TypeConstraint<T>("T")                \
+                              .HostMemory("input_sizes"),            \
+                          DepthwiseConv2dNativeBackpropInputOp<GPUDevice, T>)
+
+TF_CALL_half(REGISTER_GPU_KERNEL);
+TF_CALL_float(REGISTER_GPU_KERNEL);
+TF_CALL_double(REGISTER_GPU_KERNEL);
+#undef REGISTER_GPU_KERNEL
+
+#if CUDNN_VERSION >= 7000
+template <typename T>
+class DepthwiseConv2dGroupedConvBackpropInputOp
+    : public DepthwiseConv2dNativeBackpropInputOp<GPUDevice, T> {
+ public:
+  DepthwiseConv2dGroupedConvBackpropInputOp(OpKernelConstruction* context)
+      : DepthwiseConv2dNativeBackpropInputOp<GPUDevice, T>(context) {
+    this->use_cudnn_grouped_conv_ = true;
+  }
+};
+
+#define REGISTER_GROUPED_CONV_KERNEL(T)                              \
+  REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropInput") \
+                              .Device(DEVICE_GPU)                    \
+                              .TypeConstraint<T>("T")                \
+                              .HostMemory("input_sizes")             \
+                              .Label("cudnn_grouped_convolution"),   \
+                          DepthwiseConv2dGroupedConvBackpropInputOp<T>)
+
+TF_CALL_half(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_float(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_double(REGISTER_GROUPED_CONV_KERNEL);
+#undef REGISTER_GROUPED_CONV_KERNEL
+#endif  // CUDNN_VERSION
 #endif  // GOOGLE_CUDA
 
 // Kernels to compute the gradients of the filters for depthwise convolution.
@@ -885,8 +985,19 @@ static void DepthwiseConvBackpropFilterReference(const DepthwiseArgs& args,
   }
 }
 
+// Extern template instantiated in conv_grad_filter_ops.cc.
+extern template struct LaunchConv2DBackpropFilterOp<CPUDevice, Eigen::half>;
+extern template struct LaunchConv2DBackpropFilterOp<CPUDevice, float>;
+extern template struct LaunchConv2DBackpropFilterOp<CPUDevice, double>;
+
 #if GOOGLE_CUDA
 
+// Extern template instantiated in conv_grad_filter_ops.cc.
+extern template struct LaunchConv2DBackpropFilterOp<GPUDevice, Eigen::half>;
+extern template struct LaunchConv2DBackpropFilterOp<GPUDevice, float>;
+extern template struct LaunchConv2DBackpropFilterOp<GPUDevice, double>;
+
+// Extern template instantiated in depthwise_conv_op_gpu.cu.cc.
 extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice,
                                                            Eigen::half>;
 extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, float>;
@@ -924,6 +1035,21 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
         errors::InvalidArgument("Current implementation does not yet support "
                                 "strides in the batch and depth dimensions."));
     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+
+    // For in_depth == 1 and grouped convolutions.
+    use_cudnn_ = CanUseCudnn();
+    cudnn_use_autotune_ = CudnnUseAutotune();
+    use_cudnn_grouped_conv_ = false;
+
+    if (std::is_same<T, Eigen::half>::value) {
+      dtype_ = DT_HALF;
+    } else if (std::is_same<T, float>::value) {
+      dtype_ = DT_FLOAT;
+    } else if (std::is_same<T, double>::value) {
+      dtype_ = DT_DOUBLE;
+    } else {
+      LOG(ERROR) << "Only half, float, and double are supported.";
+    }
   }
 
   void Compute(OpKernelContext* context) override {
@@ -949,24 +1075,73 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                 {1}, 0, filter_shape, &filter_backprop));
 
-    auto out_backprop_ptr = out_backprop.template flat<T>().data();
-    auto input_ptr = input.template flat<T>().data();
-    auto filter_backprop_ptr = filter_backprop->template flat<T>().data();
     // If there is nothing to compute, return.
     if (filter_shape.num_elements() == 0) {
       return;
     }
+
+    // If in_depth==1, this operation is just a standard convolution.
+    // Depthwise convolution is a special case of cuDNN's grouped convolution.
+    bool use_cudnn = use_cudnn_ && (in_depth == 1 || use_cudnn_grouped_conv_);
+
+    VLOG(2) << "DepthwiseConv2dNativeBackpropFilter: "
+            << " Input: [" << batch << ", " << input_rows << ", " << input_cols
+            << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
+            << filter_cols << ", " << in_depth << ", " << depth_multiplier
+            << "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
+            << ", " << out_depth << "], stride = " << stride_
+            << ", pad_rows = " << pad_rows << ", pad_cols = " << pad_cols
+            << ", Use cuDNN: " << use_cudnn;
+
+    if (use_cudnn) {
+      // Reshape from TF depthwise filter to cuDNN grouped convolution filter:
+      //
+      //                  | TensorFlow       | cuDNN
+      // --------------------------------------------------------------------
+      // filter_out_depth | depth_multiplier | depth_multiplier * group_count
+      // filter_in_depth  | in_depth         | in_depth / group_count
+      //
+      // For depthwise convolution, we have group_count == in_depth.
+      int32 filter_in_depth = 1;
+      TensorShape shape =
+          TensorShape{filter_rows, filter_cols, filter_in_depth, out_depth};
+      Tensor reshaped_filter(/*type=*/dtype_);
+      OP_REQUIRES(
+          context, reshaped_filter.CopyFrom(*filter_backprop, shape),
+          errors::Internal(
+              "Failed to reshape filter tensor for grouped convolution."));
+
+      // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
+      // conv is supported.
+      launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, input,
+                /*row_dilation=*/1, /*col_dilation=*/1, stride_, stride_,
+                padding_, &reshaped_filter, data_format_);
+      return;
+    }
+
+    auto out_backprop_ptr = out_backprop.template flat<T>().data();
+    auto input_ptr = input.template flat<T>().data();
+    auto filter_backprop_ptr = filter_backprop->template flat<T>().data();
     LaunchDepthwiseConvBackpropFilterOp<Device, T>()(
         context, args, out_backprop_ptr, input_ptr, filter_backprop_ptr,
         data_format_);
   }
 
+ protected:
+  bool use_cudnn_grouped_conv_;
+
  private:
   std::vector<int32> strides_;
   Padding padding_;
   TensorFormat data_format_;
   int64 stride_;
 
+  // For in_depth == 1 and grouped convolutions.
+  LaunchConv2DBackpropFilterOp<Device, T> launcher_;
+  bool use_cudnn_;
+  bool cudnn_use_autotune_;
+  DataType dtype_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeBackpropFilterOp);
 };
 
@@ -976,24 +1151,50 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
           .Device(DEVICE_CPU)                     \
           .TypeConstraint<T>("T"),                \
       DepthwiseConv2dNativeBackpropFilterOp<CPUDevice, T>);
+TF_CALL_half(REGISTER_CPU_KERNEL);
 TF_CALL_float(REGISTER_CPU_KERNEL);
+#if !defined(PLATFORM_WINDOWS) || !defined(_DEBUG)
 TF_CALL_double(REGISTER_CPU_KERNEL);
+#endif
 #undef REGISTER_CPU_KERNEL
 
 #if GOOGLE_CUDA
-REGISTER_KERNEL_BUILDER(
-    Name("DepthwiseConv2dNativeBackpropFilter")
-        .Device(DEVICE_GPU)
-        .TypeConstraint<float>("T")
-        .HostMemory("filter_sizes"),
-    DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, float>);
-
-REGISTER_KERNEL_BUILDER(
-    Name("DepthwiseConv2dNativeBackpropFilter")
-        .Device(DEVICE_GPU)
-        .TypeConstraint<double>("T")
-        .HostMemory("filter_sizes"),
-    DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, double>);
+#define REGISTER_GPU_KERNEL(T)                                        \
+  REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropFilter") \
+                              .Device(DEVICE_GPU)                     \
+                              .TypeConstraint<T>("T")                 \
+                              .HostMemory("filter_sizes"),            \
+                          DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, T>)
+
+TF_CALL_half(REGISTER_GPU_KERNEL);
+TF_CALL_float(REGISTER_GPU_KERNEL);
+TF_CALL_double(REGISTER_GPU_KERNEL);
+#undef REGISTER_GPU_KERNEL
+
+#if CUDNN_VERSION >= 7000
+template <typename T>
+class DepthwiseConv2dGroupedConvBackpropFilterOp
+    : public DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, T> {
+ public:
+  DepthwiseConv2dGroupedConvBackpropFilterOp(OpKernelConstruction* context)
+      : DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, T>(context) {
+    this->use_cudnn_grouped_conv_ = true;
+  }
+};
+
+#define REGISTER_GROUPED_CONV_KERNEL(T)                               \
+  REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropFilter") \
+                              .Device(DEVICE_GPU)                     \
+                              .TypeConstraint<T>("T")                 \
+                              .HostMemory("filter_sizes")             \
+                              .Label("cudnn_grouped_convolution"),    \
+                          DepthwiseConv2dGroupedConvBackpropFilterOp<T>)
+
+TF_CALL_half(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_float(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_double(REGISTER_GROUPED_CONV_KERNEL);
+#undef REGISTER_GROUPED_CONV_KERNEL
+#endif  // CUDNN_VERSION
 #endif  // GOOGLE_CUDA
 
 }  // namespace tensorflow
index 6dedb1a..d5f4a68 100644 (file)
@@ -39,6 +39,7 @@ limitations under the License.
 #include "tensorflow/core/util/work_sharder.h"
 
 #if GOOGLE_CUDA
+#include "cuda/include/cudnn.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #endif  // GOOGLE_CUDA
 
@@ -241,18 +242,22 @@ struct LaunchDepthwiseConvOp<CPUDevice, T> {
 };
 
 // Extern template instantiated in conv_ops.cc.
+extern template struct LaunchConv2DOp<CPUDevice, Eigen::half>;
 extern template struct LaunchConv2DOp<CPUDevice, float>;
+extern template struct LaunchConv2DOp<CPUDevice, double>;
 
 #if GOOGLE_CUDA
 
+// Extern template instantiated in conv_ops.cc.
+extern template struct LaunchConv2DOp<GPUDevice, Eigen::half>;
+extern template struct LaunchConv2DOp<GPUDevice, float>;
+extern template struct LaunchConv2DOp<GPUDevice, double>;
+
 // Extern template instantiated in depthwise_conv_op_gpu.cc.
 extern template struct LaunchDepthwiseConvOp<GPUDevice, Eigen::half>;
 extern template struct LaunchDepthwiseConvOp<GPUDevice, float>;
 extern template struct LaunchDepthwiseConvOp<GPUDevice, double>;
 
-// Extern template instantiated in conv_ops.cc.
-extern template struct LaunchConv2DOp<GPUDevice, float>;
-
 #endif
 
 template <typename Device, typename T>
@@ -284,9 +289,11 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
                                 "strides in the batch and depth dimensions."));
     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
 
-    // For special case when in_depth == 1.
+    // For in_depth == 1 and grouped convolutions.
     use_cudnn_ = CanUseCudnn();
     cudnn_use_autotune_ = CudnnUseAutotune();
+    use_cudnn_grouped_conv_ = false;
+    dtype_ = DataTypeToEnum<T>::value;
   }
 
   void Compute(OpKernelContext* context) override {
@@ -357,27 +364,47 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
     Tensor* output = nullptr;
     OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
 
-    VLOG(2) << "DepthwiseConv2dNative: "
-            << " Input: [" << batch << ", " << input_rows << ", " << input_cols
-            << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
-            << filter_cols << ", " << in_depth << ", " << depth_multiplier
-            << "]; stride = " << stride_ << ", pad_rows = " << pad_rows
-            << ", pad_cols = " << pad_cols << ", output: [" << batch << ", "
-            << out_rows << ", " << out_cols << ", " << out_depth << "]";
-
     // If there is nothing to compute, return.
     if (out_shape.num_elements() == 0) {
       return;
     }
 
-    // If in_depth==1, this operation is just a standard convolution, so
-    // invoke that op.
-    if (std::is_same<T, float>::value && in_depth == 1) {
+    // TODO(csigg): Have autotune decide if native is faster than cuDNN.
+    // If in_depth==1, this operation is just a standard convolution.
+    // Depthwise convolution is a special case of cuDNN's grouped convolution.
+    bool use_cudnn = use_cudnn_ && (in_depth == 1 || use_cudnn_grouped_conv_);
+
+    VLOG(2) << "DepthwiseConv2dNative: "
+            << " Input: [" << batch << ", " << input_rows << ", " << input_cols
+            << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
+            << filter_cols << ", " << in_depth << ", " << depth_multiplier
+            << "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
+            << ", " << out_depth << "], stride = " << stride_
+            << ", pad_rows = " << pad_rows << ", pad_cols = " << pad_cols
+            << ", Use cuDNN: " << use_cudnn;
+
+    if (use_cudnn) {
+      // Reshape from TF depthwise filter to cuDNN grouped convolution filter:
+      //
+      //                  | TensorFlow       | cuDNN
+      // --------------------------------------------------------------------
+      // filter_out_depth | depth_multiplier | depth_multiplier * group_count
+      // filter_in_depth  | in_depth         | in_depth / group_count
+      //
+      // For depthwise convolution, we have group_count == in_depth.
+      int32 filter_in_depth = 1;
+      TensorShape shape =
+          TensorShape{filter_rows, filter_cols, filter_in_depth, out_depth};
+      Tensor reshaped_filter(/*type=*/dtype_);
+      OP_REQUIRES(
+          context, reshaped_filter.CopyFrom(filter, shape),
+          errors::Internal(
+              "Failed to reshape filter tensor for grouped convolution."));
       // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
       // conv is supported.
-      launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
-                /*row_dilation=*/1, /*col_dilation=*/1, stride_, stride_,
-                padding_, output, data_format_);
+      launcher_(context, use_cudnn_, cudnn_use_autotune_, input,
+                reshaped_filter, /*row_dilation=*/1, /*col_dilation=*/1,
+                stride_, stride_, padding_, output, data_format_);
       return;
     }
 
@@ -403,6 +430,9 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
                                        output_ptr, data_format_);
   }
 
+ protected:
+  bool use_cudnn_grouped_conv_;
+
  private:
   std::vector<int32> strides_;
   Padding padding_;
@@ -410,10 +440,11 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
 
   int64 stride_;  // in height/width dimension.
 
-  // For the case in_depth == 1.
+  // For in_depth == 1 and grouped convolutions.
   LaunchConv2DOp<Device, T> launcher_;
   bool use_cudnn_;
   bool cudnn_use_autotune_;
+  DataType dtype_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeOp);
 };
@@ -421,7 +452,7 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
 #define REGISTER_CPU_KERNEL(T)                                                 \
   REGISTER_KERNEL_BUILDER(                                                     \
       Name("DepthwiseConv2dNative").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
-      DepthwiseConv2dNativeOp<CPUDevice, T>);
+      DepthwiseConv2dNativeOp<CPUDevice, T>)
 
 TF_CALL_half(REGISTER_CPU_KERNEL);
 TF_CALL_float(REGISTER_CPU_KERNEL);
@@ -430,19 +461,38 @@ TF_CALL_double(REGISTER_CPU_KERNEL);
 #endif
 
 #if GOOGLE_CUDA
-REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<Eigen::half>("T"),
-                        DepthwiseConv2dNativeOp<GPUDevice, Eigen::half>);
-
-REGISTER_KERNEL_BUILDER(
-    Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint<float>("T"),
-    DepthwiseConv2dNativeOp<GPUDevice, float>);
-
-REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<double>("T"),
-                        DepthwiseConv2dNativeOp<GPUDevice, double>);
-#endif
+
+#define REGISTER_GPU_KERNEL(T)                                                 \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+      DepthwiseConv2dNativeOp<GPUDevice, T>)
+
+TF_CALL_half(REGISTER_GPU_KERNEL);
+TF_CALL_float(REGISTER_GPU_KERNEL);
+TF_CALL_double(REGISTER_GPU_KERNEL);
+
+#if CUDNN_VERSION >= 7000
+template <typename T>
+class DepthwiseConv2dGroupedConvOp
+    : public DepthwiseConv2dNativeOp<GPUDevice, T> {
+ public:
+  DepthwiseConv2dGroupedConvOp(OpKernelConstruction* context)
+      : DepthwiseConv2dNativeOp<GPUDevice, T>(context) {
+    this->use_cudnn_grouped_conv_ = true;
+  }
+};
+
+#define REGISTER_GROUPED_CONV_KERNEL(T)                            \
+  REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")            \
+                              .Device(DEVICE_GPU)                  \
+                              .TypeConstraint<T>("T")              \
+                              .Label("cudnn_grouped_convolution"), \
+                          DepthwiseConv2dGroupedConvOp<T>)
+
+TF_CALL_half(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_float(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_double(REGISTER_GROUPED_CONV_KERNEL);
+#endif  // CUDNN_VERSION
+#endif  // GOOGLE_CUDA
 
 }  // namespace tensorflow
index f7ae1a0..659dc04 100644 (file)
@@ -22,12 +22,15 @@ import numpy as np
 
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import nn_impl
 from tensorflow.python.ops import nn_ops
 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
 
 
 def ConfigsToTest():
@@ -98,6 +101,7 @@ class DepthwiseConv2DTest(test.TestCase):
                     padding,
                     data_type,
                     use_gpu,
+                    grouped_conv=False,
                     data_format="NHWC"):
     """Verifies the output values of the convolution function.
 
@@ -110,25 +114,26 @@ class DepthwiseConv2DTest(test.TestCase):
       padding: Padding type.
       data_type: The data type to use.
       use_gpu: Whether to use GPU.
+      grouped_conv: Whether to use cuDNN 7's grouped convolution.
       data_format: The data_format of the input. "NHWC" or "NCHW".
     """
-    total_size_1 = 1
-    total_size_2 = 1
+    input_size = 1
+    filter_size = 1
     for s in tensor_in_sizes:
-      total_size_1 *= s
+      input_size *= s
     for s in filter_in_sizes:
-      total_size_2 *= s
+      filter_size *= s
     # Initializes the input and filter tensor with numbers incrementing from 1.
-    x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
-    x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
-    with self.test_session(use_gpu=use_gpu) as sess:
-      if data_type == dtypes.float16:
-        tolerance = 1e-5
-      elif data_type == dtypes.float32:
-        tolerance = 1e-5
-      else:
-        self.assertEqual(data_type, dtypes.float64)
-        tolerance = 1e-8
+    x1 = [f * 1.0 / input_size for f in range(1, input_size + 1)]
+    x2 = [f * 1.0 / filter_size for f in range(1, filter_size + 1)]
+    ops.reset_default_graph()
+    graph = ops.get_default_graph()
+    with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
+      tolerance = {
+          dtypes.float16: 4e-2,
+          dtypes.float32: 1e-8,
+          dtypes.float64: 1e-13,
+      }[data_type]
 
       t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=data_type)
       t1.set_shape(tensor_in_sizes)
@@ -142,25 +147,39 @@ class DepthwiseConv2DTest(test.TestCase):
         native_t1 = array_ops.transpose(t1, [0, 3, 1, 2])
         strides = [1, 1, stride, stride]
 
-      conv_native = nn_ops.depthwise_conv2d_native(
-          native_t1,
-          t2,
-          strides=strides,
-          data_format=data_format,
-          padding=padding)
+      with sess.graph._kernel_label_map({
+          "DepthwiseConv2dNative": "cudnn_grouped_convolution"
+      } if grouped_conv else {}):
+        conv_native = nn_ops.depthwise_conv2d_native(
+            native_t1,
+            t2,
+            strides=strides,
+            data_format=data_format,
+            padding=padding)
 
       if data_format == "NCHW":
         # Transpose back from NCHW to NHWC
         conv_native = array_ops.transpose(conv_native, [0, 2, 3, 1])
 
+      try:
+        native_result = sess.run(conv_native)
+      except errors.InvalidArgumentError as e:
+        # Grouped convolution kernel is only registered for cuDNN 7. Silently
+        # return when we are running on an earlier version or without GPU.
+        if e.message.startswith(
+            "No OpKernel was registered to support Op 'DepthwiseConv2dNative'"):
+          tf_logging.warn("Skipping grouped convolution test")
+          return
+        raise e
+
       conv_interface = nn_impl.depthwise_conv2d(
           t1, t2, strides=[1, stride, stride, 1], padding=padding)
-
-      native_result = sess.run(conv_native)
       interface_result = sess.run(conv_interface)
 
-    print("data_type:", data_type, "use_gpu:", use_gpu, "max diff = ",
-          np.amax(np.absolute(native_result - interface_result)))
+    tf_logging.info(
+        "data_type: %r, use_gpu: %r, grouped_conv: %r, max diff = %f",
+        data_type, use_gpu, grouped_conv,
+        np.amax(np.absolute(native_result - interface_result)))
     self.assertArrayNear(
         np.ravel(native_result), np.ravel(interface_result), tolerance)
     self.assertShapeEqual(native_result, conv_native)
@@ -169,11 +188,22 @@ class DepthwiseConv2DTest(test.TestCase):
   def testDepthwiseConv2D(self):
     for index, (input_size, filter_size, _, stride,
                 padding) in enumerate(ConfigsToTest()):
-      print("Testing DepthwiseConv2D,", index, "th config:", input_size, "*",
-            filter_size, "stride:", stride, "padding:", padding)
+      tf_logging.info(
+          "Testing DepthwiseConv2D, %dth config: %r * %r, stride: %d, padding: "
+          "%s", index, input_size, filter_size, stride, padding)
       for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
+        tf_logging.info("Testing without grouped_conv")
         self._VerifyValues(
             input_size, filter_size, stride, padding, data_type, use_gpu=True)
+        tf_logging.info("Testing with grouped_conv")
+        self._VerifyValues(
+            input_size,
+            filter_size,
+            stride,
+            padding,
+            data_type,
+            use_gpu=True,
+            grouped_conv=True)
 
   def testDepthwiseConv2DFormat(self):
     if not test.is_gpu_available():
@@ -181,8 +211,9 @@ class DepthwiseConv2DTest(test.TestCase):
 
     for index, (input_size, filter_size, _, stride,
                 padding) in enumerate(ConfigsToTest()):
-      print("Testing DepthwiseConv2DFormat,", index, "th config:", input_size,
-            "*", filter_size, "stride:", stride, "padding:", padding)
+      tf_logging.info(
+          "Testing DepthwiseConv2DFormat, %dth config: %r * %r, stride: %d, "
+          "padding: %s", index, input_size, filter_size, stride, padding)
       for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
         self._VerifyValues(
             input_size,
@@ -226,7 +257,7 @@ class DepthwiseConv2DTest(test.TestCase):
       conv = nn_ops.depthwise_conv2d_native(
           t1, t2, strides=[1, stride, stride, 1], padding=padding)
       value = sess.run(conv)
-    print("value = ", value)
+    tf_logging.info("value = %r", value)
     self.assertArrayNear(expected, np.ravel(value), 1e-5)
     self.assertShapeEqual(value, conv)
 
@@ -296,7 +327,7 @@ class DepthwiseConv2DTest(test.TestCase):
         expected=expected_output,
         use_gpu=True)
 
-  # Gradient checkers.This tests depthwise gradient computations for both
+  # Gradient checkers. This tests depthwise gradient computations for both
   # BackpropFilter and BackpropInput by comparing gradients computed by the
   # depthwise gradient ops with the gradients computed numerically (details can
   # be found in the compute_gradient_error().
@@ -310,6 +341,7 @@ class DepthwiseConv2DTest(test.TestCase):
                                 data_type,
                                 test_input,
                                 use_gpu,
+                                grouped_conv=False,
                                 data_format="NHWC"):
     input_size = 1
     for x in input_shape:
@@ -319,14 +351,14 @@ class DepthwiseConv2DTest(test.TestCase):
       filter_size *= x
     input_data = [x * 1.0 / input_size for x in range(0, input_size)]
     filter_data = [x * 1.0 / filter_size for x in range(0, filter_size)]
-    with self.test_session(use_gpu=use_gpu):
-      if data_type == dtypes.float16:
-        tolerance = 0.002
-      elif data_type == dtypes.float32:
-        tolerance = 0.002
-      else:
-        self.assertEqual(data_type, dtypes.float64)
-        tolerance = 1e-8
+    ops.reset_default_graph()
+    graph = ops.get_default_graph()
+    with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
+      tolerance = {
+          dtypes.float16: 2e-0,
+          dtypes.float32: 5e-4,
+          dtypes.float64: 1e-12,
+      }[data_type]
 
       input_tensor = constant_op.constant(
           input_data, shape=input_shape, dtype=data_type, name="input")
@@ -347,35 +379,49 @@ class DepthwiseConv2DTest(test.TestCase):
         ]
         strides = [1, 1, stride, stride]
 
-      depthwise_conv2d = nn_ops.depthwise_conv2d_native(
-          native_input,
-          filter_tensor,
-          strides,
-          padding,
-          data_format=data_format,
-          name="depthwise_conv2d")
+      with sess.graph._kernel_label_map({
+          "DepthwiseConv2dNative": "cudnn_grouped_convolution",
+          "DepthwiseConv2dNativeBackpropInput": "cudnn_grouped_convolution",
+          "DepthwiseConv2dNativeBackpropFilter": "cudnn_grouped_convolution",
+      } if grouped_conv else {}):
+        depthwise_conv2d = nn_ops.depthwise_conv2d_native(
+            native_input,
+            filter_tensor,
+            strides,
+            padding,
+            data_format=data_format,
+            name="depthwise_conv2d")
 
       self.assertEqual(output_shape, depthwise_conv2d.get_shape())
-      if test_input:
-        err = gradient_checker.compute_gradient_error(
-            native_input, input_shape, depthwise_conv2d, output_shape)
-      else:
-        err = gradient_checker.compute_gradient_error(filter_tensor,
-                                                      filter_shape,
-                                                      depthwise_conv2d,
-                                                      output_shape)
-      print("data_type:", data_type, "use_gpu:", use_gpu, ", error = ", err)
+
+      try:
+        if test_input:
+          err = gradient_checker.compute_gradient_error(
+              native_input, input_shape, depthwise_conv2d, output_shape)
+        else:
+          err = gradient_checker.compute_gradient_error(
+              filter_tensor, filter_shape, depthwise_conv2d, output_shape)
+      except errors.InvalidArgumentError as e:
+        # Grouped convolution kernel is only registered for cuDNN 7. Silently
+        # return when we are running on an earlier version or without GPU.
+        if grouped_conv and e.message.startswith(
+            "No OpKernel was registered to support Op 'DepthwiseConv2dNative'"):
+          tf_logging.warn("Skipping grouped convolution test")
+          return
+        raise e
+
+      tf_logging.info(
+          "data_type: %r, use_gpu: %r, grouped_conv: %r, error = %f", data_type,
+          use_gpu, grouped_conv, err)
       self.assertLess(err, tolerance)
 
   def testDepthwiseConv2DInputGrad(self):
     for index, (input_size, filter_size, output_size, stride,
                 padding) in enumerate(CheckGradConfigsToTest()):
-      print("Testing DepthwiseConv2DInputGrad,", index, "th config:",
-            input_size, "*", filter_size, "stride:", stride, "padding:",
-            padding)
-      # Note: float16 test for DepthwiseConv2DInputGrad is not enabled,
-      # calculations are not very precise.
-      for data_type in [dtypes.float32, dtypes.float64]:
+      tf_logging.info(
+          "Testing DepthwiseConv2DInputGrad, %dth config: %r * %r, stride: %d, "
+          "padding: %s", index, input_size, filter_size, stride, padding)
+      for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
         self._ConstructAndTestGradient(
             input_size,
             filter_size,
@@ -385,6 +431,16 @@ class DepthwiseConv2DTest(test.TestCase):
             data_type,
             test_input=True,
             use_gpu=True)
+        self._ConstructAndTestGradient(
+            input_size,
+            filter_size,
+            output_size,
+            stride,
+            padding,
+            data_type,
+            test_input=True,
+            use_gpu=True,
+            grouped_conv=True)
 
   def testDepthwiseConv2DInputGradFormat(self):
     if not test.is_gpu_available():
@@ -392,12 +448,11 @@ class DepthwiseConv2DTest(test.TestCase):
 
     for index, (input_size, filter_size, output_size, stride,
                 padding) in enumerate(CheckGradConfigsToTest()):
-      print("Testing DepthwiseConv2DInputGradFormat,", index, "th config:",
-            input_size, "*", filter_size, "stride:", stride, "padding:",
-            padding)
-      # Note: float16 test for DepthwiseConv2DInputGradFormat is not enabled,
-      # calculations are not very precise.
-      for data_type in [dtypes.float32, dtypes.float64]:
+      tf_logging.info(
+          "Testing DepthwiseConv2DInputGradFormat, %dth config: %r * %r, "
+          "stride: %d, padding: %s", index, input_size, filter_size, stride,
+          padding)
+      for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
         self._ConstructAndTestGradient(
             input_size,
             filter_size,
@@ -412,12 +467,10 @@ class DepthwiseConv2DTest(test.TestCase):
   def testDepthwiseConv2DFilterGrad(self):
     for index, (input_size, filter_size, output_size, stride,
                 padding) in enumerate(CheckGradConfigsToTest()):
-      print("Testing DepthwiseConv2DFilterGrad,", index, "th config:",
-            input_size, "*", filter_size, "stride:", stride, "padding:",
-            padding)
-      # Note: float16 test for DepthwiseConv2DFilterGrad is not enabled,
-      # calculations are not very precise.
-      for data_type in [dtypes.float32, dtypes.float64]:
+      tf_logging.info(
+          "Testing DepthwiseConv2DFilterGrad, %dth config: %r * %r, stride: "
+          "%d, padding: %s", index, input_size, filter_size, stride, padding)
+      for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
         self._ConstructAndTestGradient(
             input_size,
             filter_size,
@@ -434,12 +487,11 @@ class DepthwiseConv2DTest(test.TestCase):
 
     for index, (input_size, filter_size, output_size, stride,
                 padding) in enumerate(CheckGradConfigsToTest()):
-      print("Testing DepthwiseConv2DFilterGradFormat,", index, "th config:",
-            input_size, "*", filter_size, "stride:", stride, "padding:",
-            padding)
-      # Note: float16 test for DepthwiseConv2DFilterGradFormat is not enabled,
-      # calculations are not very precise.
-      for data_type in [dtypes.float32, dtypes.float64]:
+      tf_logging.info(
+          "Testing DepthwiseConv2DFilterGradFormat, %dth config: %r * %r, "
+          "stride: %d, padding: %s", index, input_size, filter_size, stride,
+          padding)
+      for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
         self._ConstructAndTestGradient(
             input_size,
             filter_size,
@@ -494,9 +546,10 @@ class DepthwiseConv2DTest(test.TestCase):
   def testDepthwiseConv2DInputGradCompare(self):
     for index, (input_size, filter_size, output_size, stride,
                 padding) in enumerate(ConfigsToTest()):
-      print("Testing DepthwiseConv2DInputGradCompare,", index, "th config:",
-            input_size, "*", filter_size, "stride:", stride, "padding:",
-            padding)
+      tf_logging.info(
+          "Testing DepthwiseConv2DInputGradCompare, %dth config: %r * %r, "
+          "stride: %d, padding: %s", index, input_size, filter_size, stride,
+          padding)
       self._CompareBackpropInputFloat(input_size, filter_size, output_size,
                                       stride, padding)
       self._CompareBackpropInputDouble(input_size, filter_size, output_size,
@@ -545,9 +598,10 @@ class DepthwiseConv2DTest(test.TestCase):
   def testDepthwiseConv2DFilterGradCompare(self):
     for index, (input_size, filter_size, output_size, stride,
                 padding) in enumerate(ConfigsToTest()):
-      print("Testing DepthwiseConv2DFilterGradCompare,", index, "th config:",
-            input_size, "*", filter_size, "stride:", stride, "padding:",
-            padding)
+      tf_logging.info(
+          "Testing DepthwiseConv2DFilterGradCompare, %dth config: %r * %r, "
+          "stride: %d, padding: %s", index, input_size, filter_size, stride,
+          padding)
       self._CompareBackpropFilterFloat(input_size, filter_size, output_size,
                                        stride, padding)
       self._CompareBackpropFilterDouble(input_size, filter_size, output_size,
index 42a77aa..773cac2 100644 (file)
@@ -337,7 +337,9 @@ CUDNN_DNN_ROUTINE_EACH_R6_WITH_STREAM(
 #if CUDNN_VERSION >= 7000
 #define CUDNN_DNN_ROUTINE_EACH_R7(__macro)                    \
   __macro(cudnnSetConvolutionMathType)                        \
-  __macro(cudnnSetRNNMatrixMathType)
+  __macro(cudnnSetRNNMatrixMathType)                          \
+  __macro(cudnnSetConvolutionGroupCount)                      \
+  __macro(cudnnGetConvolutionGroupCount)
 
 // clang-format on
 CUDNN_DNN_ROUTINE_EACH_R7(STREAM_EXECUTOR_CUDNN_WRAP)
@@ -779,6 +781,20 @@ class ScopedConvolutionDescriptor {
     // NOTE(benbarsdell): This only applies if tensor op math is enabled
     //                      and algo selection is set to Default.
     this->set_use_tensor_op_math(true);
+
+#if CUDNN_MAJOR >= 7
+    VLOG(2) << "Requesting grouped convolution: "
+            << convolution_descriptor.group_count();
+    status = wrap::cudnnSetConvolutionGroupCount(
+        parent_, handle_, convolution_descriptor.group_count());
+    if (status != CUDNN_STATUS_SUCCESS) {
+      LOG(FATAL) << "could not set cudnn convolution group count: "
+                 << ToString(status);
+    }
+#else
+    CHECK_EQ(convolution_descriptor.group_count(), 1)
+        << "Requested grouped convolution for cuDNN version < 7";
+#endif
   }
 
   void set_use_tensor_op_math(bool use_tensor_op_math) {
index 031c82d..eed93ef 100644 (file)
@@ -434,6 +434,7 @@ ConvolutionDescriptor::ConvolutionDescriptor(int ndims)
       filter_strides_(ndims, 1),
       dilation_rates_(ndims, 1),
       pad_alignment_(PadAlignment::kDefault),
+      group_count_(1),
       ndims_(ndims) {}
 
 ConvolutionDescriptor::ConvolutionDescriptor()
index 0c2e083..18606eb 100644 (file)
@@ -543,6 +543,10 @@ class ConvolutionDescriptor {
     pad_alignment_ = pad_alignment;
     return *this;
   }
+  ConvolutionDescriptor& set_group_count(int group_count) {
+    group_count_ = group_count;
+    return *this;
+  }
   int64 zero_padding_height() const {
     return GetDim(zero_padding_, DimIndex::Y);
   }
@@ -566,6 +570,7 @@ class ConvolutionDescriptor {
   int filter_stride(DimIndex dim) const { return GetDim(filter_strides_, dim); }
   int dilation_rate(DimIndex dim) const { return GetDim(dilation_rates_, dim); }
   PadAlignment pad_alignment() const { return pad_alignment_; }
+  int group_count() const { return group_count_; }
   int ndims() const { return ndims_; }
 
   std::vector<int64> strides() const { return filter_strides_; }
@@ -578,6 +583,7 @@ class ConvolutionDescriptor {
   std::vector<int64> filter_strides_;
   std::vector<int64> dilation_rates_;
   PadAlignment pad_alignment_;
+  int group_count_;
   int ndims_;
   // TODO(leary) cudnn provides these fields, but need to characterize what
   // their effect is -- they may be boolean rather than integral.