From a5a51ad3a1200e2e5ef46c140bab717422e41ca2 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 30 Apr 2018 06:59:23 -0700
Subject: [PATCH] Adding a depthwise convolution kernel op (with label
 'cudnn_grouped_convolution') which forwards to cuDNN grouped convolutions.

PiperOrigin-RevId: 194780352
---
 tensorflow/core/kernels/BUILD                      |  10 +-
 tensorflow/core/kernels/conv_grad_filter_ops.cc    |  71 ++++--
 tensorflow/core/kernels/conv_grad_input_ops.cc     |  74 +++---
 tensorflow/core/kernels/conv_grad_ops.cc           |   7 +-
 tensorflow/core/kernels/conv_ops.cc                |  85 ++++---
 tensorflow/core/kernels/depthwise_conv_grad_op.cc  | 263 ++++++++++++++++++---
 tensorflow/core/kernels/depthwise_conv_op.cc       | 118 ++++++---
 .../python/kernel_tests/depthwise_conv_op_test.py  | 222 ++++++++++-------
 tensorflow/stream_executor/cuda/cuda_dnn.cc        |  18 +-
 tensorflow/stream_executor/dnn.cc                  |   1 +
 tensorflow/stream_executor/dnn.h                   |   6 +
 11 files changed, 637 insertions(+), 238 deletions(-)

diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 6355f13..3fb03cd 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3299,7 +3299,10 @@ tf_kernel_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:nn_ops_op_lib",
-    ] + if_cuda(["@cub_archive//:cub"]),
+    ] + if_cuda([
+        "@cub_archive//:cub",
+        "@local_config_cuda//cuda:cudnn",
+    ]),
 )
 
 tf_kernel_library(
@@ -3310,12 +3313,15 @@ tf_kernel_library(
     prefix = "depthwise_conv_grad_op",
     deps = [
         ":bounds_check",
+        ":conv_ops",
        ":ops_util",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:nn_ops_op_lib",
-    ],
+    ] + if_cuda([
+        "@local_config_cuda//cuda:cudnn",
+    ]),
 )
 
 cc_library(
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index ef1e73e..aca7517 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -96,7 +96,8 @@ template <typename T>
 struct LaunchConv2DBackpropFilterOp<CPUDevice, T> {
   void operator()(OpKernelContext* ctx, bool use_cudnn,
                   bool cudnn_use_autotune, const Tensor& out_backprop,
                   const Tensor& input,
-                  int row_stride, int col_stride, const Padding& padding,
+                  int row_dilation, int col_dilation, int row_stride,
+                  int col_stride, const Padding& padding,
                   Tensor* filter_backprop, TensorFormat data_format) {
     const CPUDevice& d = ctx->eigen_device<CPUDevice>();
     functor::SpatialConvolutionBackwardFilter<CPUDevice, T>()(
@@ -275,7 +276,8 @@ class Conv2DFastBackpropFilterOp : public OpKernel {
 #endif
 
     LaunchConv2DBackpropFilterOp<Device, T>()(
-        context, false, false, out_backprop, input, dims.spatial_dims[0].stride,
+        context, false, false, out_backprop, input,
+        /*row_dilation=*/1, /*col_dilation=*/1, dims.spatial_dims[0].stride,
         dims.spatial_dims[1].stride, padding_, filter_backprop, data_format_);
   }
 
@@ -523,6 +525,11 @@ TF_CALL_float(REGISTER_CPU_KERNELS);
 TF_CALL_double(REGISTER_CPU_KERNELS);
 #undef REGISTER_CPU_KERNELS
 
+// To be used inside depthwise_conv_grad_op.cc.
+template struct LaunchConv2DBackpropFilterOp<CPUDevice, Eigen::half>;
+template struct LaunchConv2DBackpropFilterOp<CPUDevice, float>;
+template struct LaunchConv2DBackpropFilterOp<CPUDevice, double>;
+
 // GPU definitions.
 #if GOOGLE_CUDA
 // The slow version (but compiles for GPU)
@@ -690,10 +697,15 @@ void LaunchConv2DBackpropFilterOp<GPUDevice, T>::operator()(
     return;
   }
 
+  // If the filter in-depth (filter_shape.dim_size(2)) is 1 and smaller than
+  // the input depth, it's a depthwise convolution. More generally, if the
+  // filter in-depth divides but is smaller than the input depth, it is a
+  // grouped convolution.
+  bool is_grouped_convolution = filter_shape.dim_size(2) != dims.in_depth;
   bool cudnn_disable_conv_1x1_optimization_ = CudnnDisableConv1x1Optimization();
   if (!cudnn_disable_conv_1x1_optimization_ &&
       dims.spatial_dims[0].filter_size == 1 &&
-      dims.spatial_dims[1].filter_size == 1 &&
+      dims.spatial_dims[1].filter_size == 1 && !is_grouped_convolution &&
       dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
       data_format == FORMAT_NHWC) {
     const uint64 m = dims.in_depth;
@@ -734,9 +746,10 @@ void LaunchConv2DBackpropFilterOp<GPUDevice, T>::operator()(
                  dims.spatial_dims[0].input_size &&
              dims.spatial_dims[1].filter_size ==
                  dims.spatial_dims[1].input_size &&
-             padding == VALID && data_format == FORMAT_NHWC) {
-    // The input data and filter have the same height/width, so call cublas
-    // directly.
+             !is_grouped_convolution && padding == VALID &&
+             data_format == FORMAT_NHWC) {
+    // The input data and filter have the same height/width, and we are not
+    // using grouped convolution, so call cublas directly.
     const uint64 m = dims.spatial_dims[0].input_size *
                      dims.spatial_dims[1].input_size * dims.in_depth;
     const uint64 k = dims.batch_size;
@@ -802,15 +815,16 @@ void LaunchConv2DBackpropFilterOp<GPUDevice, T>::operator()(
   se::dnn::FilterDescriptor filter_desc;
   filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
       .set_input_filter_width(dims.spatial_dims[1].filter_size)
-      .set_input_feature_map_count(dims.in_depth)
-      .set_output_feature_map_count(dims.out_depth);
+      .set_input_feature_map_count(filter_shape.dim_size(2))
+      .set_output_feature_map_count(filter_shape.dim_size(3));
   se::dnn::ConvolutionDescriptor conv_desc;
   conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
       .set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
       .set_vertical_filter_stride(dims.spatial_dims[0].stride)
       .set_horizontal_filter_stride(dims.spatial_dims[1].stride)
       .set_zero_padding_height(padding_rows / 2)
-      .set_zero_padding_width(padding_cols / 2);
+      .set_zero_padding_width(padding_cols / 2)
+      .set_group_count(dims.in_depth / filter_shape.dim_size(2));
 
   // NOTE(zhengxq):
   // cuDNN only supports the following layouts :
@@ -891,21 +905,22 @@ void LaunchConv2DBackpropFilterOp<GPUDevice, T>::operator()(
     int device_id = stream->parent()->device_ordinal();
     DataType dtype = input.dtype();
     ConvParameters conv_parameters = {
-        dims.batch_size,                       // batch
-        dims.in_depth,                         // in_depths
-        {{input_desc.height(),                 // in_rows
-          input_desc.width()}},                // in_cols
-        dims.out_depth,                        // out_depths
-        {{dims.spatial_dims[0].filter_size,    // filter_rows
-          dims.spatial_dims[1].filter_size}},  // filter_cols
-        {{dims.spatial_dims[0].dilation,       // dilation_rows
-          dims.spatial_dims[1].dilation}},     // dilation_cols
-        {{dims.spatial_dims[0].stride,         // stride_rows
-          dims.spatial_dims[1].stride}},       // stride_cols
-        {{padding_rows,                        // padding_rows
-          padding_cols}},                      // padding_cols
-        dtype,                                 // tensor datatype
-        device_id,                             // device_id
+        dims.batch_size,                       // batch
+        dims.in_depth,                         // in_depths
+        {{input_desc.height(),                 // in_rows
+          input_desc.width()}},                // in_cols
+        dims.out_depth,                        // out_depths
+        {{dims.spatial_dims[0].filter_size,    // filter_rows
+          dims.spatial_dims[1].filter_size,    // filter_cols
+          filter_shape.dim_size(2)}},          // filter_depth
+        {{dims.spatial_dims[0].dilation,       // dilation_rows
+          dims.spatial_dims[1].dilation}},     // dilation_cols
+        {{dims.spatial_dims[0].stride,         // stride_rows
+          dims.spatial_dims[1].stride}},       // stride_cols
+        {{padding_rows,                        // padding_rows
+          padding_cols}},                      // padding_cols
+        dtype,                                 // tensor datatype
+        device_id,                             // device_id
     };
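The group count that eventually reaches cuDNN is simply the ratio of input depth to filter in-depth, and the autotune cache key (ConvParameters) now carries the filter in-depth so grouped and ungrouped shapes are cached separately. A minimal sketch of the invariant being relied on here; ComputeGroupCount is a hypothetical helper name, not part of this patch:

    #include <cassert>
    #include <cstdint>

    // Sketch only: divisibility is checked upstream by
    // ConvBackpropComputeDimensionsV2 before descriptors are built.
    int64_t ComputeGroupCount(int64_t in_depth, int64_t filter_in_depth) {
      assert(filter_in_depth > 0 && in_depth % filter_in_depth == 0);
      // 1 => ordinary convolution; in_depth => depthwise convolution.
      return in_depth / filter_in_depth;
    }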
     AlgorithmConfig algorithm_config;
     if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find(
@@ -1019,9 +1034,9 @@ namespace functor {
       typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
   extern template struct PadInput<GPUDevice, T, int, 4>;
 
-DECLARE_GPU_SPEC(double);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(Eigen::half);
+DECLARE_GPU_SPEC(double);
 #undef DECLARE_GPU_SPEC
 }  // namespace functor
 
@@ -1040,6 +1055,12 @@ REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")
                             .TypeConstraint<double>("T")
                             .HostMemory("filter_sizes"),
                         Conv2DSlowBackpropFilterOp<GPUDevice, double>);
+
+// To be used inside depthwise_conv_grad_op.cc.
+template struct LaunchConv2DBackpropFilterOp<GPUDevice, Eigen::half>;
+template struct LaunchConv2DBackpropFilterOp<GPUDevice, float>;
+template struct LaunchConv2DBackpropFilterOp<GPUDevice, double>;
+
 #endif  // GOOGLE_CUDA
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index 35f2676..63a775a 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -101,8 +101,9 @@ template <typename T>
 struct LaunchConv2DBackpropInputOp<CPUDevice, T> {
   void operator()(OpKernelContext* ctx, bool use_cudnn,
                   bool cudnn_use_autotune, const Tensor& out_backprop,
                   const Tensor& filter,
-                  int row_stride, int col_stride, const Padding& padding,
-                  Tensor* in_backprop, TensorFormat data_format) {
+                  int row_dilation, int col_dilation, int row_stride,
+                  int col_stride, const Padding& padding, Tensor* in_backprop,
+                  TensorFormat data_format) {
     const CPUDevice& d = ctx->eigen_device<CPUDevice>();
     functor::SpatialConvolutionBackwardInput<CPUDevice, T>()(
         d, in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
@@ -280,8 +281,8 @@ class Conv2DFastBackpropInputOp : public OpKernel {
 
     LaunchConv2DBackpropInputOp<Device, T>()(
         context, false, false, out_backprop, filter,
-        dims.spatial_dims[0].stride, dims.spatial_dims[1].stride, padding_,
-        in_backprop, data_format_);
+        /*row_dilation=*/1, /*col_dilation=*/1, dims.spatial_dims[0].stride,
+        dims.spatial_dims[1].stride, padding_, in_backprop, data_format_);
   }
 
  private:
@@ -595,6 +596,11 @@ TF_CALL_float(REGISTER_CPU_KERNELS);
 TF_CALL_double(REGISTER_CPU_KERNELS);
 #undef REGISTER_CPU_KERNELS
 
+// To be used inside depthwise_conv_grad_op.cc.
+template struct LaunchConv2DBackpropInputOp<CPUDevice, Eigen::half>;
+template struct LaunchConv2DBackpropInputOp<CPUDevice, float>;
+template struct LaunchConv2DBackpropInputOp<CPUDevice, double>;
+
 // GPU definitions.
 #if GOOGLE_CUDA
 // The slow version (but compiles for GPU)
@@ -761,8 +767,13 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
     return;
   }
 
+  // If the filter in-depth (filter_shape.dim_size(2)) is 1 and smaller than
+  // the input depth, it's a depthwise convolution. More generally, if the
+  // filter in-depth divides but is smaller than the input depth, it is a
+  // grouped convolution.
+  bool is_grouped_convolution = filter_shape.dim_size(2) != dims.in_depth;
   if (dims.spatial_dims[0].filter_size == 1 &&
-      dims.spatial_dims[1].filter_size == 1 &&
+      dims.spatial_dims[1].filter_size == 1 && !is_grouped_convolution &&
       dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
       data_format == FORMAT_NHWC) {
     // 1x1 filter, so call cublas directly.
@@ -795,9 +806,10 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
                  dims.spatial_dims[0].input_size &&
              dims.spatial_dims[1].filter_size ==
                  dims.spatial_dims[1].input_size &&
-             padding == VALID && data_format == FORMAT_NHWC) {
-    // The input data and filter have the same height/width, so call cublas
-    // directly.
+             !is_grouped_convolution && padding == VALID &&
+             data_format == FORMAT_NHWC) {
+    // The input data and filter have the same height/width, and we are not
+    // using grouped convolution, so call cublas directly.
     const uint64 m = dims.batch_size;
     const uint64 k = dims.out_depth;
     const uint64 n = dims.spatial_dims[0].input_size *
@@ -856,15 +868,16 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
   se::dnn::FilterDescriptor filter_desc;
   filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
       .set_input_filter_width(dims.spatial_dims[1].filter_size)
-      .set_input_feature_map_count(dims.in_depth)
-      .set_output_feature_map_count(dims.out_depth);
+      .set_input_feature_map_count(filter_shape.dim_size(2))
+      .set_output_feature_map_count(filter_shape.dim_size(3));
   se::dnn::ConvolutionDescriptor conv_desc;
   conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
       .set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
       .set_vertical_filter_stride(dims.spatial_dims[0].stride)
       .set_horizontal_filter_stride(dims.spatial_dims[1].stride)
       .set_zero_padding_height(padding_rows / 2)
-      .set_zero_padding_width(padding_cols / 2);
+      .set_zero_padding_width(padding_cols / 2)
+      .set_group_count(dims.in_depth / filter_shape.dim_size(2));
 
   // NOTE(keveman):
   // cuDNN only supports the following layouts :
@@ -940,21 +953,22 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
     int device_id = stream->parent()->device_ordinal();
     DataType dtype = out_backprop.dtype();
     ConvParameters conv_parameters = {
-        dims.batch_size,                       // batch
-        dims.in_depth,                         // in_depths
-        {{input_desc.height(),                 // in_rows
-          input_desc.width()}},                // in_cols
-        dims.out_depth,                        // out_depths
-        {{dims.spatial_dims[0].filter_size,    // filter_rows
-          dims.spatial_dims[1].filter_size}},  // filter_cols
-        {{dims.spatial_dims[0].dilation,       // dilation_rows
-          dims.spatial_dims[1].dilation}},     // dilation_cols
-        {{dims.spatial_dims[0].stride,         // stride_rows
-          dims.spatial_dims[1].stride}},       // stride_cols
-        {{padding_rows,                        // padding_rows
-          padding_cols}},                      // padding_cols
-        dtype,                                 // tensor data type
-        device_id,                             // device_id
+        dims.batch_size,                       // batch
+        dims.in_depth,                         // in_depths
+        {{input_desc.height(),                 // in_rows
+          input_desc.width()}},                // in_cols
+        dims.out_depth,                        // out_depths
+        {{dims.spatial_dims[0].filter_size,    // filter_rows
+          dims.spatial_dims[1].filter_size,    // filter_cols
+          filter_shape.dim_size(2)}},          // filter_depths
+        {{dims.spatial_dims[0].dilation,       // dilation_rows
+          dims.spatial_dims[1].dilation}},     // dilation_cols
+        {{dims.spatial_dims[0].stride,         // stride_rows
+          dims.spatial_dims[1].stride}},       // stride_cols
+        {{padding_rows,                        // padding_rows
+          padding_cols}},                      // padding_cols
+        dtype,                                 // tensor data type
+        device_id,                             // device_id
     };
     AlgorithmConfig algorithm_config;
     if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(
@@ -1092,9 +1106,9 @@ namespace functor {
       typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
   extern template struct PadInput<GPUDevice, T, int, 4>;
 
-DECLARE_GPU_SPEC(double);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(Eigen::half);
+DECLARE_GPU_SPEC(double);
 #undef DECLARE_GPU_SPEC
 }  // namespace functor
 
@@ -1113,6 +1127,12 @@ REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
                             .TypeConstraint<double>("T")
                             .HostMemory("input_sizes"),
                         Conv2DSlowBackpropInputOp<GPUDevice, double>);
+
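The explicit instantiations added here pair with extern template declarations in depthwise_conv_grad_op.cc: the gradient kernels there call the generic launchers without re-instantiating them in every translation unit. The pattern, reduced to its essentials (illustrative only):

    // In conv_grad_input_ops.cc: emit the definition once.
    template struct LaunchConv2DBackpropInputOp<GPUDevice, float>;

    // In depthwise_conv_grad_op.cc: suppress implicit instantiation and link
    // against the definition emitted above.
    extern template struct LaunchConv2DBackpropInputOp<GPUDevice, float>;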
+// To be used inside depthwise_conv_grad_op.cc.
+template struct LaunchConv2DBackpropInputOp<GPUDevice, Eigen::half>;
+template struct LaunchConv2DBackpropInputOp<GPUDevice, float>;
+template struct LaunchConv2DBackpropInputOp<GPUDevice, double>;
+
 #endif  // GOOGLE_CUDA
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc
index 170ce31..5bf709a 100644
--- a/tensorflow/core/kernels/conv_grad_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_ops.cc
@@ -127,16 +127,17 @@ Status ConvBackpropComputeDimensionsV2(
   dims->in_depth = input_shape.dim_size(feature_dim);
   // The input and output feature dimensions are the second last and last
   // dimensions of the filter Tensor.
-  if (dims->in_depth != filter_shape.dim_size(num_dims - 2)) {
+  VLOG(2) << "input vs filter_in depth " << dims->in_depth << " "
+          << filter_shape.dim_size(num_dims - 2);
+  if (dims->in_depth % filter_shape.dim_size(num_dims - 2)) {
     return errors::InvalidArgument(
-        label, ": input and filter must have the same depth");
+        label, ": input depth must be evenly divisible by filter depth");
   }
   dims->out_depth = filter_shape.dim_size(num_dims - 1);
   if (dims->out_depth != out_backprop_shape.dim_size(feature_dim)) {
     return errors::InvalidArgument(
         label, ": filter and out_backprop must have the same out_depth");
   }
-
   dims->spatial_dims.resize(num_spatial_dims);
   for (int i = 0; i < num_spatial_dims; ++i) {
     int image_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
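Relaxing the equality check to a divisibility check is exactly what admits grouped filters. For example, with in_depth = 8: a filter in-depth of 8 is an ordinary convolution (1 group), 2 gives 4 groups, 1 is depthwise (8 groups), and 3 is rejected. A small illustration of the check as now written:

    #include <cstdint>

    // Mirrors the check above: a non-zero remainder means an invalid filter.
    bool FilterDepthCompatible(int64_t in_depth, int64_t filter_in_depth) {
      return filter_in_depth > 0 && in_depth % filter_in_depth == 0;
    }
    // FilterDepthCompatible(8, 8) -> true   (1 group, ordinary conv)
    // FilterDepthCompatible(8, 2) -> true   (4 groups)
    // FilterDepthCompatible(8, 1) -> true   (8 groups, depthwise)
    // FilterDepthCompatible(8, 3) -> false  (rejected with InvalidArgument)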
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index c6d36b4..3b9886e 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -18,10 +18,16 @@ limitations under the License.
 #define USE_EIGEN_TENSOR
 #define EIGEN_USE_THREADS
 
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif  // GOOGLE_CUDA
+
 #include "tensorflow/core/kernels/conv_ops.h"
+
 #include <string.h>
 #include <map>
 #include <vector>
+
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -32,9 +38,6 @@ limitations under the License.
 #include "tensorflow/core/kernels/conv_2d.h"
 #include "tensorflow/core/kernels/deep_conv2d.h"
 #include "tensorflow/core/kernels/ops_util.h"
-#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
-#include "tensorflow/core/kernels/xsmm_conv2d.h"
-#endif
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/lib/strings/numbers.h"
@@ -45,6 +48,10 @@ limitations under the License.
 #include "tensorflow/core/util/tensor_format.h"
 #include "tensorflow/core/util/use_cudnn.h"
 
+#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
+#include "tensorflow/core/kernels/xsmm_conv2d.h"
+#endif
+
 #if GOOGLE_CUDA
 #include "tensorflow/core/kernels/conv_ops_gpu.h"
 #include "tensorflow/core/platform/stream_executor.h"
@@ -123,6 +130,10 @@ struct LaunchConv2DOp<CPUDevice, T> {
                       "NHWC tensor format for now."));
       return;
     }
+    const int64 in_depth = GetTensorDim(input, data_format, 'C');
+    OP_REQUIRES(ctx, in_depth == filter.dim_size(2),
+                errors::Unimplemented("Generic conv implementation does not "
+                                      "support grouped convolutions for now."));
     LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
                                   row_dilation, col_dilation, padding, output,
                                   data_format);
@@ -324,12 +335,13 @@ class Conv2DOp : public BinaryOp<T> {
     }
 
     // The last dimension for input is in_depth. It must be the same as the
-    // filter's in_depth.
+    // filter's in_depth or be evenly divisible by filter's in_depth.
     const int64 in_depth = GetTensorDim(input, data_format_, 'C');
-    OP_REQUIRES(context, in_depth == filter.dim_size(2),
+    const int64 patch_depth = filter.dim_size(2);
+    OP_REQUIRES(context, in_depth % patch_depth == 0,
                 errors::InvalidArgument(
-                    "input and filter must have the same depth: ", in_depth,
-                    " vs ", filter.dim_size(2)));
+                    "input depth must be evenly divisible by filter depth: ",
+                    in_depth, " vs ", patch_depth));
 
     // The last dimension for filter is out_depth.
     const int out_depth = static_cast<int>(filter.dim_size(3));
@@ -386,6 +398,7 @@ class Conv2DOp : public BinaryOp<T> {
     OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
 
     VLOG(2) << "Conv2D: in_depth = " << in_depth
+            << ", patch_depth = " << patch_depth
             << ", input_cols = " << input_cols
             << ", filter_cols = " << filter_cols
             << ", input_rows = " << input_rows
@@ -450,7 +463,9 @@ TF_CALL_double(REGISTER_CPU);
 #endif  // USE_GEMM_FOR_CONV
 
 // To be used inside depthwise_conv_op.cc.
+template struct LaunchConv2DOp<CPUDevice, Eigen::half>;
 template struct LaunchConv2DOp<CPUDevice, float>;
+template struct LaunchConv2DOp<CPUDevice, double>;
 
 #if GOOGLE_CUDA
 int64 GetCudnnWorkspaceLimit(const string& envvar_in_mb,
@@ -498,13 +513,24 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
   }
 
   Tensor input = input_param;
-
-  if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_dilation == 1 &&
-      col_dilation == 1 && row_stride == 1 && col_stride == 1 &&
-      data_format == FORMAT_NHWC) {
+  const int64 in_batch = GetTensorDim(input, data_format, 'N');
+  int64 in_rows = GetTensorDim(input, data_format, 'H');
+  int64 in_cols = GetTensorDim(input, data_format, 'W');
+  const int64 in_depths = GetTensorDim(input, data_format, 'C');
+  const int64 patch_rows = filter.dim_size(0);
+  const int64 patch_cols = filter.dim_size(1);
+  const int64 patch_depths = filter.dim_size(2);
+
+  // If the filter in-depth (patch_depths) is 1 and smaller than the input
+  // depth, it's a depthwise convolution. More generally, if the filter
+  // in-depth divides but is smaller than the input depth, it is a grouped
+  // convolution.
+  bool is_grouped_convolution = patch_depths != in_depths;
+  if (patch_rows == 1 && patch_cols == 1 && !is_grouped_convolution &&
+      row_dilation == 1 && col_dilation == 1 && row_stride == 1 &&
+      col_stride == 1 && data_format == FORMAT_NHWC) {
     // 1x1 filter, so call cublas directly.
-    const uint64 m = input.dim_size(0) * input.dim_size(1) * input.dim_size(2);
-    const uint64 k = filter.dim_size(2);
+    const uint64 m = in_batch * in_rows * in_cols;
+    const uint64 k = patch_depths;
     const uint64 n = filter.dim_size(3);
 
     auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
@@ -525,15 +551,14 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
                                       ", n=", n, ", k=", k));
     }
     return;
-  } else if (filter.dim_size(0) == input.dim_size(1) &&
-             filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 &&
+  } else if (patch_rows == in_rows && patch_cols == in_cols &&
+             !is_grouped_convolution && row_dilation == 1 &&
              col_dilation == 1 && padding == VALID &&
             data_format == FORMAT_NHWC) {
     // The input data and filter have the same height/width, so call cublas
     // directly.
-    const uint64 m = input.dim_size(0);
-    const uint64 k =
-        filter.dim_size(0) * filter.dim_size(1) * filter.dim_size(2);
+    const uint64 m = in_batch;
+    const uint64 k = patch_rows * patch_cols * patch_depths;
     const uint64 n = filter.dim_size(3);
 
     auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
@@ -558,16 +583,10 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
 
   int padding_rows = 0;
   int padding_cols = 0;
-  const int64 in_batch = GetTensorDim(input, data_format, 'N');
-  int64 in_rows = GetTensorDim(input, data_format, 'H');
-  int64 in_cols = GetTensorDim(input, data_format, 'W');
-  const int64 in_depths = GetTensorDim(input, data_format, 'C');
   const int64 out_batch = GetTensorDim(*output, data_format, 'N');
   const int64 out_rows = GetTensorDim(*output, data_format, 'H');
   const int64 out_cols = GetTensorDim(*output, data_format, 'W');
   const int64 out_depths = GetTensorDim(*output, data_format, 'C');
-  const int64 patch_rows = filter.dim_size(0);
-  const int64 patch_cols = filter.dim_size(1);
   if (padding == SAME) {
     // Total padding on rows and cols is
     // Pr = (R' - 1) * S + (Kr - 1) * Dr + 1 - R
@@ -642,9 +661,9 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
       .set_feature_map_count(out_depths)
       .set_layout(se::dnn::DataLayout::kBatchDepthYX);
   se::dnn::FilterDescriptor filter_desc;
-  filter_desc.set_input_filter_height(filter.dim_size(0))
-      .set_input_filter_width(filter.dim_size(1))
-      .set_input_feature_map_count(filter.dim_size(2))
+  filter_desc.set_input_filter_height(patch_rows)
+      .set_input_filter_width(patch_cols)
+      .set_input_feature_map_count(patch_depths)
       .set_output_feature_map_count(filter.dim_size(3));
   se::dnn::ConvolutionDescriptor conv_desc;
   conv_desc.set_vertical_dilation_rate(row_dilation)
@@ -652,7 +671,8 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
       .set_vertical_filter_stride(row_stride)
       .set_horizontal_filter_stride(col_stride)
       .set_zero_padding_height(padding_rows / 2)
-      .set_zero_padding_width(padding_cols / 2);
+      .set_zero_padding_width(padding_cols / 2)
+      .set_group_count(in_depths / patch_depths);
 
   Tensor transformed_filter;
   OP_REQUIRES_OK(ctx, ctx->allocate_temp(
@@ -695,7 +715,8 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
                              in_cols}},       // in_cols
         out_depths,                           // out_depths
         {{patch_rows,                         // filter_rows
-          patch_cols}},                       // filter_cols
+          patch_cols,                         // filter_cols
+          patch_depths}},                     // filter_depths
         {{row_dilation,                       // dilation_rows
           col_dilation}},                     // dilation_cols
         {{row_stride,                         // stride_rows
@@ -812,9 +833,9 @@ namespace functor {
       typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
   extern template struct PadInput<GPUDevice, T, int, 4>
 
-DECLARE_GPU_SPEC(double);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(Eigen::half);
+DECLARE_GPU_SPEC(double);
 #undef DECLARE_GPU_SPEC
 }  // namespace functor
 
@@ -830,7 +851,9 @@ REGISTER_KERNEL_BUILDER(
     Conv2DOp<GPUDevice, double>);
 
 // To be used inside depthwise_conv_op.cc.
-template class LaunchConv2DOp<GPUDevice, float>;
+template struct LaunchConv2DOp<GPUDevice, Eigen::half>;
+template struct LaunchConv2DOp<GPUDevice, float>;
+template struct LaunchConv2DOp<GPUDevice, double>;
 
 #endif  // GOOGLE_CUDA
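Both cublas fast paths above rely on the convolution being expressible as a single dense GEMM, which only holds when every output channel sees every input channel; that is why they now bail out for grouped convolutions. A reference for the 1x1 equivalence (a sketch, not the implementation above):

    #include <cstdint>
    #include <vector>

    // An ungrouped 1x1, stride-1 convolution over NHWC data is one matmul:
    //   A: [batch * height * width, in_depth]   (input)
    //   B: [in_depth, out_depth]                (filter)
    //   C = A * B
    std::vector<float> Conv1x1AsGemm(const std::vector<float>& a,  // [m, k]
                                     const std::vector<float>& b,  // [k, n]
                                     int64_t m, int64_t k, int64_t n) {
      std::vector<float> c(m * n, 0.f);
      for (int64_t i = 0; i < m; ++i)
        for (int64_t j = 0; j < n; ++j)
          for (int64_t p = 0; p < k; ++p)
            c[i * n + j] += a[i * k + p] * b[p * n + j];
      return c;
    }
    // With group_count > 1 the filter matrix B is block-diagonal, so a single
    // dense GEMM no longer computes the convolution; hence the bail-out.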
#include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/kernels/depthwise_conv_op.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" @@ -33,9 +34,11 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/core/util/use_cudnn.h" #include "tensorflow/core/util/work_sharder.h" #if GOOGLE_CUDA +#include "cuda/include/cudnn.h" #include "tensorflow/core/platform/stream_executor.h" #endif // GOOGLE_CUDA @@ -509,8 +512,19 @@ static void DepthwiseConvBackpropInputReference(const DepthwiseArgs& args, } } +// Extern template instantiated in conv_grad_input_ops.cc. +extern template struct LaunchConv2DBackpropInputOp; +extern template struct LaunchConv2DBackpropInputOp; +extern template struct LaunchConv2DBackpropInputOp; + #if GOOGLE_CUDA +// Extern template instantiated in conv_grad_input_ops.cc. +extern template struct LaunchConv2DBackpropInputOp; +extern template struct LaunchConv2DBackpropInputOp; +extern template struct LaunchConv2DBackpropInputOp; + +// Extern template instantiated in depthwise_conv_op_gpu.cu.cc. extern template struct LaunchDepthwiseConvBackpropInputOp; extern template struct LaunchDepthwiseConvBackpropInputOp; @@ -548,6 +562,12 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel { errors::InvalidArgument("Current implementation does not yet support " "strides in the batch and depth dimensions.")); OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + + // For in_depth == 1 and grouped convolutions. + use_cudnn_ = CanUseCudnn(); + cudnn_use_autotune_ = CudnnUseAutotune(); + use_cudnn_grouped_conv_ = false; + dtype_ = DataTypeToEnum::value; } void Compute(OpKernelContext* context) override { @@ -560,6 +580,7 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel { input_sizes.dims())); TensorShape input_shape; const int32* in_sizes_data = input_sizes.template flat().data(); + for (int i = 0; i < input_sizes.NumElements(); ++i) { OP_REQUIRES(context, in_sizes_data[i] >= 0, errors::InvalidArgument("Dimension ", i, @@ -568,27 +589,77 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel { } const TensorShape& filter_shape = filter.shape(); EXTRACT_AND_VERIFY_DIMENSIONS("DepthwiseConv2DBackpropInput"); + Tensor* in_backprop = nullptr; OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( {0}, 0, input_shape, &in_backprop)); - auto out_backprop_ptr = out_backprop.template flat().data(); - auto filter_ptr = filter.template flat().data(); - auto in_backprop_ptr = in_backprop->template flat().data(); + // If there is nothing to compute, return. if (input_shape.num_elements() == 0) { return; } + + // If in_depth==1, this operation is just a standard convolution. + // Depthwise convolution is a special case of cuDNN's grouped convolution. 
+ bool use_cudnn = use_cudnn_ && (in_depth == 1 || use_cudnn_grouped_conv_); + + VLOG(2) << "DepthwiseConv2dNativeBackpropInput: " + << " Input: [" << batch << ", " << input_rows << ", " << input_cols + << ", " << in_depth << "]; Filter: [" << filter_rows << ", " + << filter_cols << ", " << in_depth << ", " << depth_multiplier + << "]; Output: [" << batch << ", " << out_rows << ", " << out_cols + << ", " << out_depth << "], stride = " << stride_ + << ", pad_rows = " << pad_rows << ", pad_cols = " << pad_cols + << ", Use cuDNN: " << use_cudnn; + + if (use_cudnn) { + // Reshape from TF depthwise filter to cuDNN grouped convolution filter: + // + // | TensorFlow | cuDNN + // -------------------------------------------------------------------- + // filter_out_depth | depth_multiplier | depth_multiplier * group_count + // filter_in_depth | in_depth | in_depth / group_count + // + // For depthwise convolution, we have group_count == in_depth. + int32 filter_in_depth = 1; + TensorShape shape = + TensorShape{filter_rows, filter_cols, filter_in_depth, out_depth}; + Tensor reshaped_filter(/*type=*/dtype_); + OP_REQUIRES( + context, reshaped_filter.CopyFrom(filter, shape), + errors::Internal( + "Failed to reshape filter tensor for grouped convolution.")); + // TODO(yangzihao): Send in arbitrary dilation rates after the dilated + // conv is supported. + launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, + reshaped_filter, /*row_dilation=*/1, /*col_dilation=*/1, + stride_, stride_, padding_, in_backprop, data_format_); + return; + } + + auto out_backprop_ptr = out_backprop.template flat().data(); + auto filter_ptr = filter.template flat().data(); + auto in_backprop_ptr = in_backprop->template flat().data(); LaunchDepthwiseConvBackpropInputOp()( context, args, out_backprop_ptr, filter_ptr, in_backprop_ptr, data_format_); } + protected: + bool use_cudnn_grouped_conv_; + private: std::vector strides_; Padding padding_; TensorFormat data_format_; int64 stride_; + // For in_depth == 1 and grouped convolutions. 
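The reshape above only reinterprets the filter layout; no data moves. A concrete depthwise example of the table in the comment, written against TF's Tensor API as used in the patch (shapes are illustrative):

    // TF depthwise filter:  [filter_rows, filter_cols, in_depth, multiplier]
    //                     = [3, 3, 8, 2]
    // cuDNN grouped filter: [filter_rows, filter_cols, in_depth / groups,
    //                        multiplier * groups]
    //                     = [3, 3, 1, 16] with group_count = 8.
    Tensor reshaped(DT_FLOAT);
    CHECK(reshaped.CopyFrom(filter, TensorShape({3, 3, 1, 16})));
    // CopyFrom shares the underlying buffer, so this is O(1), and the element
    // order (row, col, channel, multiplier) is already what cuDNN expects.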
@@ -597,23 +668,52 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
                               .Device(DEVICE_CPU)                    \
                               .TypeConstraint<T>("T"),               \
                           DepthwiseConv2dNativeBackpropInputOp<CPUDevice, T>);
+
+TF_CALL_half(REGISTER_CPU_KERNEL);
 TF_CALL_float(REGISTER_CPU_KERNEL);
+#if !defined(PLATFORM_WINDOWS) || !defined(_DEBUG)
 TF_CALL_double(REGISTER_CPU_KERNEL);
+#endif
 #undef REGISTER_CPU_KERNEL
 
 #if GOOGLE_CUDA
-REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropInput")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<float>("T")
-                            .HostMemory("input_sizes"),
-                        DepthwiseConv2dNativeBackpropInputOp<GPUDevice, float>);
-
-REGISTER_KERNEL_BUILDER(
-    Name("DepthwiseConv2dNativeBackpropInput")
-        .Device(DEVICE_GPU)
-        .TypeConstraint<double>("T")
-        .HostMemory("input_sizes"),
-    DepthwiseConv2dNativeBackpropInputOp<GPUDevice, double>);
+
+#define REGISTER_GPU_KERNEL(T)                                       \
+  REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropInput") \
+                              .Device(DEVICE_GPU)                    \
+                              .TypeConstraint<T>("T")                \
+                              .HostMemory("input_sizes"),            \
+                          DepthwiseConv2dNativeBackpropInputOp<GPUDevice, T>)
+
+TF_CALL_half(REGISTER_GPU_KERNEL);
+TF_CALL_float(REGISTER_GPU_KERNEL);
+TF_CALL_double(REGISTER_GPU_KERNEL);
+#undef REGISTER_GPU_KERNEL
+
+#if CUDNN_VERSION >= 7000
+template <typename T>
+class DepthwiseConv2dGroupedConvBackpropInputOp
+    : public DepthwiseConv2dNativeBackpropInputOp<GPUDevice, T> {
+ public:
+  DepthwiseConv2dGroupedConvBackpropInputOp(OpKernelConstruction* context)
+      : DepthwiseConv2dNativeBackpropInputOp<GPUDevice, T>(context) {
+    this->use_cudnn_grouped_conv_ = true;
+  }
+};
+
+#define REGISTER_GROUPED_CONV_KERNEL(T)                              \
+  REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropInput") \
+                              .Device(DEVICE_GPU)                    \
+                              .TypeConstraint<T>("T")                \
+                              .HostMemory("input_sizes")             \
+                              .Label("cudnn_grouped_convolution"),   \
+                          DepthwiseConv2dGroupedConvBackpropInputOp<T>)
+
+TF_CALL_half(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_float(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_double(REGISTER_GROUPED_CONV_KERNEL);
+#undef REGISTER_GROUPED_CONV_KERNEL
+#endif  // CUDNN_VERSION
 #endif  // GOOGLE_CUDA
 
 // Kernels to compute the gradients of the filters for depthwise convolution.
@@ -885,8 +985,19 @@ static void DepthwiseConvBackpropFilterReference(const DepthwiseArgs& args,
   }
 }
 
+// Extern template instantiated in conv_grad_filter_ops.cc.
+extern template struct LaunchConv2DBackpropFilterOp<CPUDevice, Eigen::half>;
+extern template struct LaunchConv2DBackpropFilterOp<CPUDevice, float>;
+extern template struct LaunchConv2DBackpropFilterOp<CPUDevice, double>;
+
 #if GOOGLE_CUDA
 
+// Extern template instantiated in conv_grad_filter_ops.cc.
+extern template struct LaunchConv2DBackpropFilterOp<GPUDevice, Eigen::half>;
+extern template struct LaunchConv2DBackpropFilterOp<GPUDevice, float>;
+extern template struct LaunchConv2DBackpropFilterOp<GPUDevice, double>;
+
+// Extern template instantiated in depthwise_conv_op_gpu.cu.cc.
 extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, Eigen::half>;
 extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, float>;
 extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, double>;
@@ -924,6 +1035,21 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
         errors::InvalidArgument("Current implementation does not yet support "
                                 "strides in the batch and depth dimensions."));
     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+
+    // For in_depth == 1 and grouped convolutions.
+    use_cudnn_ = CanUseCudnn();
+    cudnn_use_autotune_ = CudnnUseAutotune();
+    use_cudnn_grouped_conv_ = false;
+
+    if (std::is_same<T, Eigen::half>::value) {
+      dtype_ = DT_HALF;
+    } else if (std::is_same<T, float>::value) {
+      dtype_ = DT_FLOAT;
+    } else if (std::is_same<T, double>::value) {
+      dtype_ = DT_DOUBLE;
+    } else {
+      LOG(ERROR) << "Only half, float, and double are supported.";
+    }
   }
 
   void Compute(OpKernelContext* context) override {
@@ -949,24 +1075,73 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                 {1}, 0, filter_shape, &filter_backprop));
 
-    auto out_backprop_ptr = out_backprop.template flat<T>().data();
-    auto input_ptr = input.template flat<T>().data();
-    auto filter_backprop_ptr = filter_backprop->template flat<T>().data();
     // If there is nothing to compute, return.
     if (filter_shape.num_elements() == 0) {
       return;
     }
+
+    // If in_depth==1, this operation is just a standard convolution.
+    // Depthwise convolution is a special case of cuDNN's grouped convolution.
+    bool use_cudnn = use_cudnn_ && (in_depth == 1 || use_cudnn_grouped_conv_);
+
+    VLOG(2) << "DepthwiseConv2dNativeBackpropFilter: "
+            << " Input: [" << batch << ", " << input_rows << ", " << input_cols
+            << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
+            << filter_cols << ", " << in_depth << ", " << depth_multiplier
+            << "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
+            << ", " << out_depth << "], stride = " << stride_
+            << ", pad_rows = " << pad_rows << ", pad_cols = " << pad_cols
+            << ", Use cuDNN: " << use_cudnn;
+
+    if (use_cudnn) {
+      // Reshape from TF depthwise filter to cuDNN grouped convolution filter:
+      //
+      //                  | TensorFlow       | cuDNN
+      // --------------------------------------------------------------------
+      // filter_out_depth | depth_multiplier | depth_multiplier * group_count
+      // filter_in_depth  | in_depth         | in_depth / group_count
+      //
+      // For depthwise convolution, we have group_count == in_depth.
+      int32 filter_in_depth = 1;
+      TensorShape shape =
+          TensorShape{filter_rows, filter_cols, filter_in_depth, out_depth};
+      Tensor reshaped_filter(/*type=*/dtype_);
+      OP_REQUIRES(
+          context, reshaped_filter.CopyFrom(*filter_backprop, shape),
+          errors::Internal(
+              "Failed to reshape filter tensor for grouped convolution."));
+
+      // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
+      // conv is supported.
+      launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, input,
+                /*row_dilation=*/1, /*col_dilation=*/1, stride_, stride_,
+                padding_, &reshaped_filter, data_format_);
+      return;
+    }
+
+    auto out_backprop_ptr = out_backprop.template flat<T>().data();
+    auto input_ptr = input.template flat<T>().data();
+    auto filter_backprop_ptr = filter_backprop->template flat<T>().data();
     LaunchDepthwiseConvBackpropFilterOp<Device, T>()(
         context, args, out_backprop_ptr, input_ptr, filter_backprop_ptr,
         data_format_);
   }
 
+ protected:
+  bool use_cudnn_grouped_conv_;
+
  private:
   std::vector<int32> strides_;
   Padding padding_;
   TensorFormat data_format_;
   int64 stride_;
 
+  // For in_depth == 1 and grouped convolutions.
+  LaunchConv2DBackpropFilterOp<Device, T> launcher_;
+  bool use_cudnn_;
+  bool cudnn_use_autotune_;
+  DataType dtype_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeBackpropFilterOp);
 };
 
@@ -976,24 +1151,50 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
                               .Device(DEVICE_CPU)                     \
                               .TypeConstraint<T>("T"),                \
                           DepthwiseConv2dNativeBackpropFilterOp<CPUDevice, T>);
+TF_CALL_half(REGISTER_CPU_KERNEL);
 TF_CALL_float(REGISTER_CPU_KERNEL);
+#if !defined(PLATFORM_WINDOWS) || !defined(_DEBUG)
 TF_CALL_double(REGISTER_CPU_KERNEL);
+#endif
 #undef REGISTER_CPU_KERNEL
 
 #if GOOGLE_CUDA
-REGISTER_KERNEL_BUILDER(
-    Name("DepthwiseConv2dNativeBackpropFilter")
-        .Device(DEVICE_GPU)
-        .TypeConstraint<float>("T")
-        .HostMemory("filter_sizes"),
-    DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, float>);
-
-REGISTER_KERNEL_BUILDER(
-    Name("DepthwiseConv2dNativeBackpropFilter")
-        .Device(DEVICE_GPU)
-        .TypeConstraint<double>("T")
-        .HostMemory("filter_sizes"),
-    DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, double>);
+#define REGISTER_GPU_KERNEL(T)                                        \
+  REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropFilter") \
+                              .Device(DEVICE_GPU)                     \
+                              .TypeConstraint<T>("T")                 \
+                              .HostMemory("filter_sizes"),            \
+                          DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, T>)
+
+TF_CALL_half(REGISTER_GPU_KERNEL);
+TF_CALL_float(REGISTER_GPU_KERNEL);
+TF_CALL_double(REGISTER_GPU_KERNEL);
+#undef REGISTER_GPU_KERNEL
+
+#if CUDNN_VERSION >= 7000
+template <typename T>
+class DepthwiseConv2dGroupedConvBackpropFilterOp
+    : public DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, T> {
+ public:
+  DepthwiseConv2dGroupedConvBackpropFilterOp(OpKernelConstruction* context)
+      : DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, T>(context) {
+    this->use_cudnn_grouped_conv_ = true;
+  }
+};
+
+#define REGISTER_GROUPED_CONV_KERNEL(T)                               \
+  REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropFilter") \
+                              .Device(DEVICE_GPU)                     \
+                              .TypeConstraint<T>("T")                 \
+                              .HostMemory("filter_sizes")             \
+                              .Label("cudnn_grouped_convolution"),    \
+                          DepthwiseConv2dGroupedConvBackpropFilterOp<T>)
+
+TF_CALL_half(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_float(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_double(REGISTER_GROUPED_CONV_KERNEL);
+#undef REGISTER_GROUPED_CONV_KERNEL
+#endif  // CUDNN_VERSION
 #endif  // GOOGLE_CUDA
 
 }  // namespace tensorflow
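For reference, REGISTER_GROUPED_CONV_KERNEL(float) above expands to roughly the following. The Label clause makes the kernel strictly opt-in: it is selected only when a node carries _kernel="cudnn_grouped_convolution", so unlabeled graphs keep the native kernels:

    REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropFilter")
                                .Device(DEVICE_GPU)
                                .TypeConstraint<float>("T")
                                .HostMemory("filter_sizes")
                                .Label("cudnn_grouped_convolution"),
                            DepthwiseConv2dGroupedConvBackpropFilterOp<float>);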
diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc
index 6dedb1a..d5f4a68 100644
--- a/tensorflow/core/kernels/depthwise_conv_op.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op.cc
@@ -39,6 +39,7 @@ limitations under the License.
 #include "tensorflow/core/util/work_sharder.h"
 
 #if GOOGLE_CUDA
+#include "cuda/include/cudnn.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #endif  // GOOGLE_CUDA
 
@@ -241,18 +242,22 @@ struct LaunchDepthwiseConvOp<CPUDevice, T> {
 };
 
 // Extern template instantiated in conv_ops.cc.
+extern template struct LaunchConv2DOp<CPUDevice, Eigen::half>;
 extern template struct LaunchConv2DOp<CPUDevice, float>;
+extern template struct LaunchConv2DOp<CPUDevice, double>;
 
 #if GOOGLE_CUDA
 
+// Extern template instantiated in conv_ops.cc.
+extern template struct LaunchConv2DOp<GPUDevice, Eigen::half>;
+extern template struct LaunchConv2DOp<GPUDevice, float>;
+extern template struct LaunchConv2DOp<GPUDevice, double>;
+
 // Extern template instantiated in depthwise_conv_op_gpu.cc.
 extern template struct LaunchDepthwiseConvOp<GPUDevice, Eigen::half>;
 extern template struct LaunchDepthwiseConvOp<GPUDevice, float>;
 extern template struct LaunchDepthwiseConvOp<GPUDevice, double>;
 
-// Extern template instantiated in conv_ops.cc.
-extern template struct LaunchConv2DOp<GPUDevice, float>;
-
 #endif
 
 template <typename Device, typename T>
@@ -284,9 +289,11 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
                                 "strides in the batch and depth dimensions."));
     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
 
-    // For special case when in_depth == 1.
+    // For in_depth == 1 and grouped convolutions.
     use_cudnn_ = CanUseCudnn();
     cudnn_use_autotune_ = CudnnUseAutotune();
+    use_cudnn_grouped_conv_ = false;
+    dtype_ = DataTypeToEnum<T>::value;
   }
 
   void Compute(OpKernelContext* context) override {
@@ -357,27 +364,47 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
     Tensor* output = nullptr;
     OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
 
-    VLOG(2) << "DepthwiseConv2dNative: "
-            << " Input: [" << batch << ", " << input_rows << ", " << input_cols
-            << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
-            << filter_cols << ", " << in_depth << ", " << depth_multiplier
-            << "]; stride = " << stride_ << ", pad_rows = " << pad_rows
-            << ", pad_cols = " << pad_cols << ", output: [" << batch << ", "
-            << out_rows << ", " << out_cols << ", " << out_depth << "]";
-
     // If there is nothing to compute, return.
     if (out_shape.num_elements() == 0) {
       return;
     }
 
-    // If in_depth==1, this operation is just a standard convolution, so
-    // invoke that op.
-    if (std::is_same<T, float>::value && in_depth == 1) {
+    // TODO(csigg): Have autotune decide if native is faster than cuDNN.
+    // If in_depth==1, this operation is just a standard convolution.
+    // Depthwise convolution is a special case of cuDNN's grouped convolution.
+    bool use_cudnn = use_cudnn_ && (in_depth == 1 || use_cudnn_grouped_conv_);
+
+    VLOG(2) << "DepthwiseConv2dNative: "
+            << " Input: [" << batch << ", " << input_rows << ", " << input_cols
+            << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
+            << filter_cols << ", " << in_depth << ", " << depth_multiplier
+            << "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
+            << ", " << out_depth << "], stride = " << stride_
+            << ", pad_rows = " << pad_rows << ", pad_cols = " << pad_cols
+            << ", Use cuDNN: " << use_cudnn;
+
+    if (use_cudnn) {
+      // Reshape from TF depthwise filter to cuDNN grouped convolution filter:
+      //
+      //                  | TensorFlow       | cuDNN
+      // --------------------------------------------------------------------
+      // filter_out_depth | depth_multiplier | depth_multiplier * group_count
+      // filter_in_depth  | in_depth         | in_depth / group_count
+      //
+      // For depthwise convolution, we have group_count == in_depth.
+      int32 filter_in_depth = 1;
+      TensorShape shape =
+          TensorShape{filter_rows, filter_cols, filter_in_depth, out_depth};
+      Tensor reshaped_filter(/*type=*/dtype_);
+      OP_REQUIRES(
+          context, reshaped_filter.CopyFrom(filter, shape),
+          errors::Internal(
+              "Failed to reshape filter tensor for grouped convolution."));
       // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
       // conv is supported.
-      launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
-                /*row_dilation=*/1, /*col_dilation=*/1, stride_, stride_,
-                padding_, output, data_format_);
+      launcher_(context, use_cudnn_, cudnn_use_autotune_, input,
+                reshaped_filter, /*row_dilation=*/1, /*col_dilation=*/1,
+                stride_, stride_, padding_, output, data_format_);
       return;
     }
 
@@ -403,6 +430,9 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
         output_ptr, data_format_);
   }
 
+ protected:
+  bool use_cudnn_grouped_conv_;
+
  private:
   std::vector<int32> strides_;
   Padding padding_;
@@ -410,10 +440,11 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
 
   int64 stride_;  // in height/width dimension.
 
-  // For the case in_depth == 1.
+  // For in_depth == 1 and grouped convolutions.
   LaunchConv2DOp<Device, T> launcher_;
   bool use_cudnn_;
   bool cudnn_use_autotune_;
+  DataType dtype_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeOp);
 };
 
@@ -421,7 +452,7 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
 #define REGISTER_CPU_KERNEL(T)                                                 \
   REGISTER_KERNEL_BUILDER(                                                     \
       Name("DepthwiseConv2dNative").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
-      DepthwiseConv2dNativeOp<CPUDevice, T>);
+      DepthwiseConv2dNativeOp<CPUDevice, T>)
 
 TF_CALL_half(REGISTER_CPU_KERNEL);
 TF_CALL_float(REGISTER_CPU_KERNEL);
@@ -430,19 +461,38 @@ TF_CALL_double(REGISTER_CPU_KERNEL);
 #endif
 
 #if GOOGLE_CUDA
-REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<Eigen::half>("T"),
-                        DepthwiseConv2dNativeOp<GPUDevice, Eigen::half>);
-
-REGISTER_KERNEL_BUILDER(
-    Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint<float>("T"),
-    DepthwiseConv2dNativeOp<GPUDevice, float>);
-
-REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<double>("T"),
-                        DepthwiseConv2dNativeOp<GPUDevice, double>);
-#endif
+
+#define REGISTER_GPU_KERNEL(T)                                                 \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+      DepthwiseConv2dNativeOp<GPUDevice, T>)
+
+TF_CALL_half(REGISTER_GPU_KERNEL);
+TF_CALL_float(REGISTER_GPU_KERNEL);
+TF_CALL_double(REGISTER_GPU_KERNEL);
+
+#if CUDNN_VERSION >= 7000
+template <typename T>
+class DepthwiseConv2dGroupedConvOp
+    : public DepthwiseConv2dNativeOp<GPUDevice, T> {
+ public:
+  DepthwiseConv2dGroupedConvOp(OpKernelConstruction* context)
+      : DepthwiseConv2dNativeOp<GPUDevice, T>(context) {
+    this->use_cudnn_grouped_conv_ = true;
+  }
+};
+
+#define REGISTER_GROUPED_CONV_KERNEL(T)                            \
+  REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")            \
+                              .Device(DEVICE_GPU)                  \
+                              .TypeConstraint<T>("T")              \
+                              .Label("cudnn_grouped_convolution"), \
+                          DepthwiseConv2dGroupedConvOp<T>)
+
+TF_CALL_half(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_float(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_double(REGISTER_GROUPED_CONV_KERNEL);
+#undef REGISTER_GROUPED_CONV_KERNEL
+#endif  // CUDNN_VERSION
+#endif  // GOOGLE_CUDA
 
 }  // namespace tensorflow
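The Python tests that follow compare this op against nn_impl.depthwise_conv2d and against numeric gradients. As a reminder of the semantics being verified, here is a tiny NumPy reference for the VALID, NHWC case (illustrative only, not part of the test suite):

    import numpy as np

    def depthwise_conv2d_reference(x, f, stride=1):
      """out[n,i,j,k*m+q] = sum_{di,dj} x[n,i*s+di,j*s+dj,k] * f[di,dj,k,q]."""
      n, h, w, c = x.shape
      fh, fw, fc, m = f.shape
      assert fc == c, "filter in-depth must equal input depth for depthwise"
      oh, ow = (h - fh) // stride + 1, (w - fw) // stride + 1
      out = np.zeros((n, oh, ow, c * m), dtype=x.dtype)
      for i in range(oh):
        for j in range(ow):
          patch = x[:, i * stride:i * stride + fh,
                    j * stride:j * stride + fw, :]
          # Sum over the window; channel k stays separate (the "groups").
          out[:, i, j, :] = np.einsum("nhwc,hwcq->ncq", patch, f).reshape(n, -1)
      return out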
diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
index f7ae1a0..659dc04 100644
--- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
+++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
@@ -22,12 +22,15 @@ import numpy as np
 
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import nn_impl
 from tensorflow.python.ops import nn_ops
 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
 
 
 def ConfigsToTest():
@@ -98,6 +101,7 @@ class DepthwiseConv2DTest(test.TestCase):
                     padding,
                     data_type,
                     use_gpu,
+                    grouped_conv=False,
                     data_format="NHWC"):
     """Verifies the output values of the convolution function.
 
@@ -110,25 +114,26 @@ class DepthwiseConv2DTest(test.TestCase):
       padding: Padding type.
       data_type: The data type to use.
       use_gpu: Whether to use GPU.
+      grouped_conv: Whether to use cuDNN 7's grouped convolution.
       data_format: The data_format of the input. "NHWC" or "NCHW".
     """
-    total_size_1 = 1
-    total_size_2 = 1
+    input_size = 1
+    filter_size = 1
     for s in tensor_in_sizes:
-      total_size_1 *= s
+      input_size *= s
     for s in filter_in_sizes:
-      total_size_2 *= s
+      filter_size *= s
     # Initializes the input and filter tensor with numbers incrementing from 1.
-    x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
-    x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
-    with self.test_session(use_gpu=use_gpu) as sess:
-      if data_type == dtypes.float16:
-        tolerance = 1e-5
-      elif data_type == dtypes.float32:
-        tolerance = 1e-5
-      else:
-        self.assertEqual(data_type, dtypes.float64)
-        tolerance = 1e-8
+    x1 = [f * 1.0 / input_size for f in range(1, input_size + 1)]
+    x2 = [f * 1.0 / filter_size for f in range(1, filter_size + 1)]
+    ops.reset_default_graph()
+    graph = ops.get_default_graph()
+    with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
+      tolerance = {
+          dtypes.float16: 4e-2,
+          dtypes.float32: 1e-8,
+          dtypes.float64: 1e-13,
+      }[data_type]
 
       t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=data_type)
       t1.set_shape(tensor_in_sizes)
@@ -142,25 +147,39 @@ class DepthwiseConv2DTest(test.TestCase):
         native_t1 = array_ops.transpose(t1, [0, 3, 1, 2])
         strides = [1, 1, stride, stride]
 
-      conv_native = nn_ops.depthwise_conv2d_native(
-          native_t1,
-          t2,
-          strides=strides,
-          data_format=data_format,
-          padding=padding)
+      with sess.graph._kernel_label_map({
+          "DepthwiseConv2dNative": "cudnn_grouped_convolution"
+      } if grouped_conv else {}):
+        conv_native = nn_ops.depthwise_conv2d_native(
+            native_t1,
+            t2,
+            strides=strides,
+            data_format=data_format,
+            padding=padding)
 
       if data_format == "NCHW":
         # Transpose back from NCHW to NHWC
         conv_native = array_ops.transpose(conv_native, [0, 2, 3, 1])
 
+      try:
+        native_result = sess.run(conv_native)
+      except errors.InvalidArgumentError as e:
+        # Grouped convolution kernel is only registered for cuDNN 7. Silently
+        # return when we are running on an earlier version or without GPU.
+        if e.message.startswith(
+            "No OpKernel was registered to support Op 'DepthwiseConv2dNative'"):
+          tf_logging.warn("Skipping grouped convolution test")
+          return
+        raise e
+
       conv_interface = nn_impl.depthwise_conv2d(
           t1, t2, strides=[1, stride, stride, 1], padding=padding)
-
-      native_result = sess.run(conv_native)
       interface_result = sess.run(conv_interface)
 
-    print("data_type:", data_type, "use_gpu:", use_gpu, "max diff = ",
-          np.amax(np.absolute(native_result - interface_result)))
+    tf_logging.info(
+        "data_type: %r, use_gpu: %r, grouped_conv: %r, max diff = %f",
+        data_type, use_gpu, grouped_conv,
+        np.amax(np.absolute(native_result - interface_result)))
     self.assertArrayNear(
         np.ravel(native_result), np.ravel(interface_result), tolerance)
     self.assertShapeEqual(native_result, conv_native)
@@ -169,11 +188,22 @@ class DepthwiseConv2DTest(test.TestCase):
   def testDepthwiseConv2D(self):
     for index, (input_size, filter_size, _, stride,
                 padding) in enumerate(ConfigsToTest()):
-      print("Testing DepthwiseConv2D,", index, "th config:", input_size, "*",
-            filter_size, "stride:", stride, "padding:", padding)
+      tf_logging.info(
+          "Testing DepthwiseConv2D, %dth config: %r * %r, stride: %d, padding: "
+          "%s", index, input_size, filter_size, stride, padding)
       for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
+        tf_logging.info("Testing without grouped_conv")
         self._VerifyValues(
             input_size, filter_size, stride, padding, data_type, use_gpu=True)
+        tf_logging.info("Testing with grouped_conv")
+        self._VerifyValues(
+            input_size,
+            filter_size,
+            stride,
+            padding,
+            data_type,
+            use_gpu=True,
+            grouped_conv=True)
 
   def testDepthwiseConv2DFormat(self):
     if not test.is_gpu_available():
@@ -181,8 +211,9 @@ class DepthwiseConv2DTest(test.TestCase):
 
     for index, (input_size, filter_size, _, stride,
                 padding) in enumerate(ConfigsToTest()):
-      print("Testing DepthwiseConv2DFormat,", index, "th config:", input_size,
-            "*", filter_size, "stride:", stride, "padding:", padding)
+      tf_logging.info(
+          "Testing DepthwiseConv2DFormat, %dth config: %r * %r, stride: %d, "
+          "padding: %s", index, input_size, filter_size, stride, padding)
       for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
         self._VerifyValues(
             input_size,
@@ -226,7 +257,7 @@ class DepthwiseConv2DTest(test.TestCase):
       conv = nn_ops.depthwise_conv2d_native(
           t1, t2, strides=[1, stride, stride, 1], padding=padding)
       value = sess.run(conv)
-    print("value = ", value)
+    tf_logging.info("value = %r", value)
     self.assertArrayNear(expected, np.ravel(value), 1e-5)
     self.assertShapeEqual(value, conv)
 
@@ -296,7 +327,7 @@ class DepthwiseConv2DTest(test.TestCase):
         expected=expected_output,
         use_gpu=True)
 
-  # Gradient checkers.This tests depthwise gradient computations for both
+  # Gradient checkers. This tests depthwise gradient computations for both
   # BackpropFilter and BackpropInput by comparing gradients computed by the
   # depthwise gradient ops with the gradients computed numerically (details can
   # be found in the compute_gradient_error().
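The changes above route the op onto the cuDNN grouped kernel through the graph's kernel label map. Distilled to a standalone sketch; note that _kernel_label_map is a private TF 1.x Graph API, used here exactly as the test uses it:

    import numpy as np
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import nn_ops

    graph = ops.Graph()
    with graph.as_default():
      x = constant_op.constant(np.random.rand(1, 5, 5, 8), dtype=dtypes.float32)
      w = constant_op.constant(np.random.rand(3, 3, 8, 2), dtype=dtypes.float32)
      # Label the node so the grouped-convolution kernel (cuDNN 7+) is chosen.
      with graph._kernel_label_map(
          {"DepthwiseConv2dNative": "cudnn_grouped_convolution"}):
        y = nn_ops.depthwise_conv2d_native(
            x, w, strides=[1, 1, 1, 1], padding="SAME")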
@@ -310,6 +341,7 @@ class DepthwiseConv2DTest(test.TestCase):
                                  data_type,
                                  test_input,
                                  use_gpu,
+                                 grouped_conv=False,
                                  data_format="NHWC"):
     input_size = 1
     for x in input_shape:
@@ -319,14 +351,14 @@ class DepthwiseConv2DTest(test.TestCase):
       filter_size *= x
     input_data = [x * 1.0 / input_size for x in range(0, input_size)]
     filter_data = [x * 1.0 / filter_size for x in range(0, filter_size)]
-    with self.test_session(use_gpu=use_gpu):
-      if data_type == dtypes.float16:
-        tolerance = 0.002
-      elif data_type == dtypes.float32:
-        tolerance = 0.002
-      else:
-        self.assertEqual(data_type, dtypes.float64)
-        tolerance = 1e-8
+    ops.reset_default_graph()
+    graph = ops.get_default_graph()
+    with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
+      tolerance = {
+          dtypes.float16: 2e-0,
+          dtypes.float32: 5e-4,
+          dtypes.float64: 1e-12,
+      }[data_type]
 
       input_tensor = constant_op.constant(
           input_data, shape=input_shape, dtype=data_type, name="input")
@@ -347,35 +379,49 @@ class DepthwiseConv2DTest(test.TestCase):
         ]
         strides = [1, 1, stride, stride]
 
-      depthwise_conv2d = nn_ops.depthwise_conv2d_native(
-          native_input,
-          filter_tensor,
-          strides,
-          padding,
-          data_format=data_format,
-          name="depthwise_conv2d")
+      with sess.graph._kernel_label_map({
+          "DepthwiseConv2dNative": "cudnn_grouped_convolution",
+          "DepthwiseConv2dNativeBackpropInput": "cudnn_grouped_convolution",
+          "DepthwiseConv2dNativeBackpropFilter": "cudnn_grouped_convolution",
+      } if grouped_conv else {}):
+        depthwise_conv2d = nn_ops.depthwise_conv2d_native(
+            native_input,
+            filter_tensor,
+            strides,
+            padding,
+            data_format=data_format,
+            name="depthwise_conv2d")
 
       self.assertEqual(output_shape, depthwise_conv2d.get_shape())
-      if test_input:
-        err = gradient_checker.compute_gradient_error(
-            native_input, input_shape, depthwise_conv2d, output_shape)
-      else:
-        err = gradient_checker.compute_gradient_error(filter_tensor,
-                                                      filter_shape,
-                                                      depthwise_conv2d,
-                                                      output_shape)
-      print("data_type:", data_type, "use_gpu:", use_gpu, ", error = ", err)
+
+      try:
+        if test_input:
+          err = gradient_checker.compute_gradient_error(
+              native_input, input_shape, depthwise_conv2d, output_shape)
+        else:
+          err = gradient_checker.compute_gradient_error(
+              filter_tensor, filter_shape, depthwise_conv2d, output_shape)
+      except errors.InvalidArgumentError as e:
+        # Grouped convolution kernel is only registered for cuDNN 7. Silently
+        # return when we are running on an earlier version or without GPU.
+        if grouped_conv and e.message.startswith(
+            "No OpKernel was registered to support Op 'DepthwiseConv2dNative'"):
+          tf_logging.warn("Skipping grouped convolution test")
+          return
+        raise e
+
+      tf_logging.info(
+          "data_type: %r, use_gpu: %r, grouped_conv: %r, error = %f", data_type,
+          use_gpu, grouped_conv, err)
       self.assertLess(err, tolerance)
 
   def testDepthwiseConv2DInputGrad(self):
     for index, (input_size, filter_size, output_size, stride,
                 padding) in enumerate(CheckGradConfigsToTest()):
-      print("Testing DepthwiseConv2DInputGrad,", index, "th config:",
-            input_size, "*", filter_size, "stride:", stride, "padding:",
-            padding)
-      # Note: float16 test for DepthwiseConv2DInputGrad is not enabled,
-      # calculations are not very precise.
-      for data_type in [dtypes.float32, dtypes.float64]:
+      tf_logging.info(
+          "Testing DepthwiseConv2DInputGrad, %dth config: %r * %r, stride: %d, "
+          "padding: %s", index, input_size, filter_size, stride, padding)
+      for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
         self._ConstructAndTestGradient(
             input_size,
             filter_size,
@@ -385,6 +431,16 @@ class DepthwiseConv2DTest(test.TestCase):
             data_type,
             test_input=True,
             use_gpu=True)
+        self._ConstructAndTestGradient(
+            input_size,
+            filter_size,
+            output_size,
+            stride,
+            padding,
+            data_type,
+            test_input=True,
+            use_gpu=True,
+            grouped_conv=True)
 
   def testDepthwiseConv2DInputGradFormat(self):
     if not test.is_gpu_available():
@@ -392,12 +448,11 @@ class DepthwiseConv2DTest(test.TestCase):
 
     for index, (input_size, filter_size, output_size, stride,
                 padding) in enumerate(CheckGradConfigsToTest()):
-      print("Testing DepthwiseConv2DInputGradFormat,", index, "th config:",
-            input_size, "*", filter_size, "stride:", stride, "padding:",
-            padding)
-      # Note: float16 test for DepthwiseConv2DInputGradFormat is not enabled,
-      # calculations are not very precise.
-      for data_type in [dtypes.float32, dtypes.float64]:
+      tf_logging.info(
+          "Testing DepthwiseConv2DInputGradFormat, %dth config: %r * %r, "
+          "stride: %d, padding: %s", index, input_size, filter_size, stride,
+          padding)
+      for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
         self._ConstructAndTestGradient(
             input_size,
             filter_size,
@@ -412,12 +467,10 @@ class DepthwiseConv2DTest(test.TestCase):
   def testDepthwiseConv2DFilterGrad(self):
     for index, (input_size, filter_size, output_size, stride,
                 padding) in enumerate(CheckGradConfigsToTest()):
-      print("Testing DepthwiseConv2DFilterGrad,", index, "th config:",
-            input_size, "*", filter_size, "stride:", stride, "padding:",
-            padding)
-      # Note: float16 test for DepthwiseConv2DFilterGrad is not enabled,
-      # calculations are not very precise.
-      for data_type in [dtypes.float32, dtypes.float64]:
+      tf_logging.info(
+          "Testing DepthwiseConv2DFilterGrad, %dth config: %r * %r, stride: "
+          "%d, padding: %s", index, input_size, filter_size, stride, padding)
+      for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
         self._ConstructAndTestGradient(
             input_size,
             filter_size,
@@ -434,12 +487,11 @@ class DepthwiseConv2DTest(test.TestCase):
 
     for index, (input_size, filter_size, output_size, stride,
                 padding) in enumerate(CheckGradConfigsToTest()):
-      print("Testing DepthwiseConv2DFilterGradFormat,", index, "th config:",
-            input_size, "*", filter_size, "stride:", stride, "padding:",
-            padding)
-      # Note: float16 test for DepthwiseConv2DFilterGradFormat is not enabled,
-      # calculations are not very precise.
-      for data_type in [dtypes.float32, dtypes.float64]:
+      tf_logging.info(
+          "Testing DepthwiseConv2DFilterGradFormat, %dth config: %r * %r, "
+          "stride: %d, padding: %s", index, input_size, filter_size, stride,
+          padding)
+      for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
         self._ConstructAndTestGradient(
             input_size,
             filter_size,
@@ -494,9 +546,10 @@ class DepthwiseConv2DTest(test.TestCase):
   def testDepthwiseConv2DInputGradCompare(self):
     for index, (input_size, filter_size, output_size, stride,
                 padding) in enumerate(ConfigsToTest()):
-      print("Testing DepthwiseConv2DInputGradCompare,", index, "th config:",
-            input_size, "*", filter_size, "stride:", stride, "padding:",
-            padding)
+      tf_logging.info(
+          "Testing DepthwiseConv2DInputGradCompare, %dth config: %r * %r, "
+          "stride: %d, padding: %s", index, input_size, filter_size, stride,
+          padding)
       self._CompareBackpropInputFloat(input_size, filter_size, output_size,
                                       stride, padding)
       self._CompareBackpropInputDouble(input_size, filter_size, output_size,
@@ -545,9 +598,10 @@ class DepthwiseConv2DTest(test.TestCase):
   def testDepthwiseConv2DFilterGradCompare(self):
     for index, (input_size, filter_size, output_size, stride,
                 padding) in enumerate(ConfigsToTest()):
-      print("Testing DepthwiseConv2DFilterGradCompare,", index, "th config:",
-            input_size, "*", filter_size, "stride:", stride, "padding:",
-            padding)
+      tf_logging.info(
+          "Testing DepthwiseConv2DFilterGradCompare, %dth config: %r * %r, "
+          "stride: %d, padding: %s", index, input_size, filter_size, stride,
+          padding)
       self._CompareBackpropFilterFloat(input_size, filter_size, output_size,
                                        stride, padding)
       self._CompareBackpropFilterDouble(input_size, filter_size, output_size,
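The stream_executor changes below are a thin wrapper over cuDNN 7's group-count API; at the cuDNN level the whole feature is a single call on the convolution descriptor. A sketch of the raw API being wrapped (error handling reduced to returning the status):

    #include <cudnn.h>

    // Requires cuDNN 7+. A freshly created descriptor defaults to 1 group.
    cudnnStatus_t SetGroups(cudnnConvolutionDescriptor_t conv_desc,
                            int group_count) {
      return cudnnSetConvolutionGroupCount(conv_desc, group_count);
    }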
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 42a77aa..773cac2 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -337,7 +337,9 @@ CUDNN_DNN_ROUTINE_EACH_R6_WITH_STREAM(
 #if CUDNN_VERSION >= 7000
 #define CUDNN_DNN_ROUTINE_EACH_R7(__macro)    \
   __macro(cudnnSetConvolutionMathType)        \
-  __macro(cudnnSetRNNMatrixMathType)
+  __macro(cudnnSetRNNMatrixMathType)          \
+  __macro(cudnnSetConvolutionGroupCount)      \
+  __macro(cudnnGetConvolutionGroupCount)
 // clang-format on
 
 CUDNN_DNN_ROUTINE_EACH_R7(STREAM_EXECUTOR_CUDNN_WRAP)
@@ -779,6 +781,20 @@ class ScopedConvolutionDescriptor {
       // NOTE(benbarsdell): This only applies if tensor op math is enabled
       //                    and algo selection is set to Default.
       this->set_use_tensor_op_math(true);
+
+#if CUDNN_MAJOR >= 7
+    VLOG(2) << "Requesting grouped convolution: "
+            << convolution_descriptor.group_count();
+    status = wrap::cudnnSetConvolutionGroupCount(
+        parent_, handle_, convolution_descriptor.group_count());
+    if (status != CUDNN_STATUS_SUCCESS) {
+      LOG(FATAL) << "could not set cudnn convolution group count: "
+                 << ToString(status);
+    }
+#else
+    CHECK_EQ(convolution_descriptor.group_count(), 1)
+        << "Requested grouped convolution for cuDNN version < 7";
+#endif
   }
 
   void set_use_tensor_op_math(bool use_tensor_op_math) {
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index 031c82d..eed93ef 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -434,6 +434,7 @@ ConvolutionDescriptor::ConvolutionDescriptor(int ndims)
       filter_strides_(ndims, 1),
       dilation_rates_(ndims, 1),
       pad_alignment_(PadAlignment::kDefault),
+      group_count_(1),
       ndims_(ndims) {}
 
 ConvolutionDescriptor::ConvolutionDescriptor()
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 0c2e083..18606eb 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -543,6 +543,10 @@ class ConvolutionDescriptor {
     pad_alignment_ = pad_alignment;
     return *this;
   }
+  ConvolutionDescriptor& set_group_count(int group_count) {
+    group_count_ = group_count;
+    return *this;
+  }
   int64 zero_padding_height() const {
     return GetDim(zero_padding_, DimIndex::Y);
   }
@@ -566,6 +570,7 @@ class ConvolutionDescriptor {
   int filter_stride(DimIndex dim) const { return GetDim(filter_strides_, dim); }
   int dilation_rate(DimIndex dim) const { return GetDim(dilation_rates_, dim); }
   PadAlignment pad_alignment() const { return pad_alignment_; }
+  int group_count() const { return group_count_; }
   int ndims() const { return ndims_; }
 
   std::vector<int64> strides() const { return filter_strides_; }
@@ -578,6 +583,7 @@ class ConvolutionDescriptor {
   std::vector<int64> filter_strides_;
   std::vector<int64> dilation_rates_;
   PadAlignment pad_alignment_;
+  int group_count_;
   int ndims_;
   // TODO(leary) cudnn provides these fields, but need to characterize what
   // their effect is -- they may be boolean rather than integral.
-- 
2.7.4
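Putting the stream_executor side together, a caller that wants a grouped convolution now configures the descriptor as below (a sketch against the API added in dnn.h; in_depth and filter_in_depth are assumed locals, and the stride/padding values are illustrative):

    se::dnn::ConvolutionDescriptor conv_desc;
    conv_desc.set_vertical_filter_stride(1)
        .set_horizontal_filter_stride(1)
        .set_zero_padding_height(1)
        .set_zero_padding_width(1)
        .set_group_count(in_depth / filter_in_depth);  // 1 == ordinary conv.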