"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:nn_ops_op_lib",
- ] + if_cuda(["@cub_archive//:cub"]),
+ ] + if_cuda([
+ "@cub_archive//:cub",
+ "@local_config_cuda//cuda:cudnn",
+ ]),
)
tf_kernel_library(
prefix = "depthwise_conv_grad_op",
deps = [
":bounds_check",
+ ":conv_ops",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:nn_ops_op_lib",
- ],
+ ] + if_cuda([
+ "@local_config_cuda//cuda:cudnn",
+ ]),
)
cc_library(
struct LaunchConv2DBackpropFilterOp<CPUDevice, T> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& out_backprop, const Tensor& input,
- int row_stride, int col_stride, const Padding& padding,
+ int row_dilation, int col_dilation, int row_stride,
+ int col_stride, const Padding& padding,
Tensor* filter_backprop, TensorFormat data_format) {
const CPUDevice& d = ctx->eigen_device<CPUDevice>();
functor::SpatialConvolutionBackwardFilter<CPUDevice, T>()(
#endif
LaunchConv2DBackpropFilterOp<Device, T>()(
- context, false, false, out_backprop, input, dims.spatial_dims[0].stride,
+ context, false, false, out_backprop, input,
+ /*row_dilation=*/1, /*col_dilation=*/1, dims.spatial_dims[0].stride,
dims.spatial_dims[1].stride, padding_, filter_backprop, data_format_);
}
TF_CALL_double(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
+// To be used inside depthwise_conv_grad_op.cc.
+template struct LaunchConv2DBackpropFilterOp<CPUDevice, Eigen::half>;
+template struct LaunchConv2DBackpropFilterOp<CPUDevice, float>;
+template struct LaunchConv2DBackpropFilterOp<CPUDevice, double>;
+
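These explicit instantiations pair with the extern template declarations added to depthwise_conv_grad_op.cc further down; a minimal sketch of the pattern, using hypothetical names, looks like this:

// launcher.h -- the template is declared once, visible to every user.
template <typename T>
struct Launcher {
  void operator()(T value);
};

// launcher.cc -- definition plus explicit instantiation definitions; the
// compiler emits code for these specializations in this object file only.
template <typename T>
void Launcher<T>::operator()(T value) { /* ... */ }
template struct Launcher<float>;
template struct Launcher<double>;

// user.cc (includes launcher.h) -- extern declarations suppress implicit
// instantiation here and reuse the symbols emitted by launcher.cc.
extern template struct Launcher<float>;
extern template struct Launcher<double>;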
// GPU definitions.
#if GOOGLE_CUDA
// The slow version (but compiles for GPU)
return;
}
+ // If the filter in-depth (filter_shape.dim_size(2)) is 1 and smaller than the
+ // input depth, it's a depthwise convolution. More generally, if the filter
+ // in-depth divides but is smaller than the input depth, it is a grouped
+ // convolution.
+ bool is_grouped_convolution = filter_shape.dim_size(2) != dims.in_depth;
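The same rule, spelled out with illustrative depths (the numbers are examples, not taken from this change):

// Filter shape is [filter_rows, filter_cols, filter_in_depth, out_depth] and
// group_count = in_depth / filter_in_depth; divisibility is validated where
// the backprop dimensions are computed.
//   in_depth = 8, filter_in_depth = 8  ->  group_count = 1  (regular conv)
//   in_depth = 8, filter_in_depth = 2  ->  group_count = 4  (grouped conv)
//   in_depth = 8, filter_in_depth = 1  ->  group_count = 8  (depthwise conv)
const int64 group_count = dims.in_depth / filter_shape.dim_size(2);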
bool cudnn_disable_conv_1x1_optimization_ = CudnnDisableConv1x1Optimization();
if (!cudnn_disable_conv_1x1_optimization_ &&
dims.spatial_dims[0].filter_size == 1 &&
- dims.spatial_dims[1].filter_size == 1 &&
+ dims.spatial_dims[1].filter_size == 1 && !is_grouped_convolution &&
dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
data_format == FORMAT_NHWC) {
const uint64 m = dims.in_depth;
dims.spatial_dims[0].input_size &&
dims.spatial_dims[1].filter_size ==
dims.spatial_dims[1].input_size &&
- padding == VALID && data_format == FORMAT_NHWC) {
- // The input data and filter have the same height/width, so call cublas
- // directly.
+ !is_grouped_convolution && padding == VALID &&
+ data_format == FORMAT_NHWC) {
+ // The input data and filter have the same height/width, and we are not
+ // using grouped convolution, so call cublas directly.
const uint64 m = dims.spatial_dims[0].input_size *
dims.spatial_dims[1].input_size * dims.in_depth;
const uint64 k = dims.batch_size;
se::dnn::FilterDescriptor filter_desc;
filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
.set_input_filter_width(dims.spatial_dims[1].filter_size)
- .set_input_feature_map_count(dims.in_depth)
- .set_output_feature_map_count(dims.out_depth);
+ .set_input_feature_map_count(filter_shape.dim_size(2))
+ .set_output_feature_map_count(filter_shape.dim_size(3));
se::dnn::ConvolutionDescriptor conv_desc;
conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
.set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
.set_vertical_filter_stride(dims.spatial_dims[0].stride)
.set_horizontal_filter_stride(dims.spatial_dims[1].stride)
.set_zero_padding_height(padding_rows / 2)
- .set_zero_padding_width(padding_cols / 2);
+ .set_zero_padding_width(padding_cols / 2)
+ .set_group_count(dims.in_depth / filter_shape.dim_size(2));
// NOTE(zhengxq):
// cuDNN only supports the following layouts :
int device_id = stream->parent()->device_ordinal();
DataType dtype = input.dtype();
ConvParameters conv_parameters = {
- dims.batch_size, // batch
- dims.in_depth, // in_depths
- {{input_desc.height(), // in_rows
- input_desc.width()}}, // in_cols
- dims.out_depth, // out_depths
- {{dims.spatial_dims[0].filter_size, // filter_rows
- dims.spatial_dims[1].filter_size}}, // filter_cols
- {{dims.spatial_dims[0].dilation, // dilation_rows
- dims.spatial_dims[1].dilation}}, // dilation_cols
- {{dims.spatial_dims[0].stride, // stride_rows
- dims.spatial_dims[1].stride}}, // stride_cols
- {{padding_rows, // padding_rows
- padding_cols}}, // padding_cols
- dtype, // tensor datatype
- device_id, // device_id
+ dims.batch_size, // batch
+ dims.in_depth, // in_depths
+ {{input_desc.height(), // in_rows
+ input_desc.width()}}, // in_cols
+ dims.out_depth, // out_depths
+ {{dims.spatial_dims[0].filter_size, // filter_rows
+ dims.spatial_dims[1].filter_size, // filter_cols
+ filter_shape.dim_size(2)}}, // filter_depth
+ {{dims.spatial_dims[0].dilation, // dilation_rows
+ dims.spatial_dims[1].dilation}}, // dilation_cols
+ {{dims.spatial_dims[0].stride, // stride_rows
+ dims.spatial_dims[1].stride}}, // stride_cols
+ {{padding_rows, // padding_rows
+ padding_cols}}, // padding_cols
+ dtype, // tensor datatype
+ device_id, // device_id
};
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find(
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
extern template struct PadInput<GPUDevice, T, int, 4>;
-DECLARE_GPU_SPEC(double);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
+DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
} // namespace functor
.TypeConstraint<Eigen::half>("T")
.HostMemory("filter_sizes"),
Conv2DSlowBackpropFilterOp<GPUDevice, Eigen::half>);
+
+// To be used inside depthwise_conv_grad_op.cc.
+template struct LaunchConv2DBackpropFilterOp<GPUDevice, float>;
+template struct LaunchConv2DBackpropFilterOp<GPUDevice, Eigen::half>;
+template struct LaunchConv2DBackpropFilterOp<GPUDevice, double>;
+
#endif // GOOGLE_CUDA
} // namespace tensorflow
struct LaunchConv2DBackpropInputOp<CPUDevice, T> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& out_backprop, const Tensor& filter,
- int row_stride, int col_stride, const Padding& padding,
- Tensor* in_backprop, TensorFormat data_format) {
+ int row_dilation, int col_dilation, int row_stride,
+ int col_stride, const Padding& padding, Tensor* in_backprop,
+ TensorFormat data_format) {
const CPUDevice& d = ctx->eigen_device<CPUDevice>();
functor::SpatialConvolutionBackwardInput<CPUDevice, T>()(
d, in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
LaunchConv2DBackpropInputOp<Device, T>()(
context, false, false, out_backprop, filter,
- dims.spatial_dims[0].stride, dims.spatial_dims[1].stride, padding_,
- in_backprop, data_format_);
+ /*row_dilation=*/1, /*col_dilation=*/1, dims.spatial_dims[0].stride,
+ dims.spatial_dims[1].stride, padding_, in_backprop, data_format_);
}
private:
TF_CALL_double(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
+// To be used inside depthwise_conv_grad_op.cc.
+template struct LaunchConv2DBackpropInputOp<CPUDevice, Eigen::half>;
+template struct LaunchConv2DBackpropInputOp<CPUDevice, float>;
+template struct LaunchConv2DBackpropInputOp<CPUDevice, double>;
+
// GPU definitions.
#if GOOGLE_CUDA
// The slow version (but compiles for GPU)
return;
}
+ // If the filter in-depth (filter_shape.dim_size(2)) is 1 and smaller than the
+ // input depth, it's a depthwise convolution. More generally, if the filter
+ // in-depth divides but is smaller than the input depth, it is a grouped
+ // convolution.
+ bool is_grouped_convolution = filter_shape.dim_size(2) != dims.in_depth;
if (dims.spatial_dims[0].filter_size == 1 &&
- dims.spatial_dims[1].filter_size == 1 &&
+ dims.spatial_dims[1].filter_size == 1 && !is_grouped_convolution &&
dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
data_format == FORMAT_NHWC) {
// 1x1 filter, so call cublas directly.
dims.spatial_dims[0].input_size &&
dims.spatial_dims[1].filter_size ==
dims.spatial_dims[1].input_size &&
- padding == VALID && data_format == FORMAT_NHWC) {
- // The input data and filter have the same height/width, so call cublas
- // directly.
+ !is_grouped_convolution && padding == VALID &&
+ data_format == FORMAT_NHWC) {
+ // The input data and filter have the same height/width, and we are not
+ // using grouped convolution, so call cublas directly.
const uint64 m = dims.batch_size;
const uint64 k = dims.out_depth;
const uint64 n = dims.spatial_dims[0].input_size *
se::dnn::FilterDescriptor filter_desc;
filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
.set_input_filter_width(dims.spatial_dims[1].filter_size)
- .set_input_feature_map_count(dims.in_depth)
- .set_output_feature_map_count(dims.out_depth);
+ .set_input_feature_map_count(filter_shape.dim_size(2))
+ .set_output_feature_map_count(filter_shape.dim_size(3));
se::dnn::ConvolutionDescriptor conv_desc;
conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
.set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
.set_vertical_filter_stride(dims.spatial_dims[0].stride)
.set_horizontal_filter_stride(dims.spatial_dims[1].stride)
.set_zero_padding_height(padding_rows / 2)
- .set_zero_padding_width(padding_cols / 2);
+ .set_zero_padding_width(padding_cols / 2)
+ .set_group_count(dims.in_depth / filter_shape.dim_size(2));
// NOTE(keveman):
// cuDNN only supports the following layouts :
int device_id = stream->parent()->device_ordinal();
DataType dtype = out_backprop.dtype();
ConvParameters conv_parameters = {
- dims.batch_size, // batch
- dims.in_depth, // in_depths
- {{input_desc.height(), // in_rows
- input_desc.width()}}, // in_cols
- dims.out_depth, // out_depths
- {{dims.spatial_dims[0].filter_size, // filter_rows
- dims.spatial_dims[1].filter_size}}, // filter_cols
- {{dims.spatial_dims[0].dilation, // dilation_rows
- dims.spatial_dims[1].dilation}}, // dilation_cols
- {{dims.spatial_dims[0].stride, // stride_rows
- dims.spatial_dims[1].stride}}, // stride_cols
- {{padding_rows, // padding_rows
- padding_cols}}, // padding_cols
- dtype, // tensor data type
- device_id, // device_id
+ dims.batch_size, // batch
+ dims.in_depth, // in_depths
+ {{input_desc.height(), // in_rows
+ input_desc.width()}}, // in_cols
+ dims.out_depth, // out_depths
+ {{dims.spatial_dims[0].filter_size, // filter_rows
+ dims.spatial_dims[1].filter_size, // filter_cols
+ filter_shape.dim_size(2)}}, // filter_depths
+ {{dims.spatial_dims[0].dilation, // dilation_rows
+ dims.spatial_dims[1].dilation}}, // dilation_cols
+ {{dims.spatial_dims[0].stride, // stride_rows
+ dims.spatial_dims[1].stride}}, // stride_cols
+ {{padding_rows, // padding_rows
+ padding_cols}}, // padding_cols
+ dtype, // tensor data type
+ device_id, // device_id
};
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
extern template struct PadInput<GPUDevice, T, int, 4>;
-DECLARE_GPU_SPEC(double);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
+DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
} // namespace functor
.TypeConstraint<Eigen::half>("T")
.HostMemory("input_sizes"),
Conv2DSlowBackpropInputOp<GPUDevice, Eigen::half>);
+
+// To be used inside depthwise_conv_grad_op.cc.
+template struct LaunchConv2DBackpropInputOp<GPUDevice, float>;
+template struct LaunchConv2DBackpropInputOp<GPUDevice, Eigen::half>;
+template struct LaunchConv2DBackpropInputOp<GPUDevice, double>;
+
#endif // GOOGLE_CUDA
} // namespace tensorflow
dims->in_depth = input_shape.dim_size(feature_dim);
// The input and output feature dimensions are the second last and last
// dimensions of the filter Tensor.
- if (dims->in_depth != filter_shape.dim_size(num_dims - 2)) {
+ VLOG(2) << "input vs filter_in depth " << dims->in_depth << " "
+ << filter_shape.dim_size(num_dims - 2);
+ if (dims->in_depth % filter_shape.dim_size(num_dims - 2)) {
return errors::InvalidArgument(
- label, ": input and filter must have the same depth");
+ label, ": input depth must be evenly divisible by filter depth");
}
dims->out_depth = filter_shape.dim_size(num_dims - 1);
if (dims->out_depth != out_backprop_shape.dim_size(feature_dim)) {
return errors::InvalidArgument(
label, ": filter and out_backprop must have the same out_depth");
}
-
dims->spatial_dims.resize(num_spatial_dims);
for (int i = 0; i < num_spatial_dims; ++i) {
int image_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif // GOOGLE_CUDA
+
#include "tensorflow/core/kernels/conv_ops.h"
+
#include <string.h>
#include <map>
#include <vector>
+
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/deep_conv2d.h"
#include "tensorflow/core/kernels/ops_util.h"
-#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
-#include "tensorflow/core/kernels/xsmm_conv2d.h"
-#endif
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
+#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
+#include "tensorflow/core/kernels/xsmm_conv2d.h"
+#endif
+
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
"NHWC tensor format for now."));
return;
}
+ const int64 in_depth = GetTensorDim(input, data_format, 'C');
+ OP_REQUIRES(ctx, in_depth == filter.dim_size(2),
+ errors::Unimplemented("Generic conv implementation does not "
+ "support grouped convolutions for now."));
LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
row_dilation, col_dilation, padding, output,
data_format);
}
// The last dimension for input is in_depth. It must be the same as the
- // filter's in_depth.
+ // filter's in_depth or be evenly divisible by filter's in_depth.
const int64 in_depth = GetTensorDim(input, data_format_, 'C');
- OP_REQUIRES(context, in_depth == filter.dim_size(2),
+ const int64 patch_depth = filter.dim_size(2);
+ OP_REQUIRES(context, in_depth % patch_depth == 0,
errors::InvalidArgument(
- "input and filter must have the same depth: ", in_depth,
- " vs ", filter.dim_size(2)));
+ "input depth must be evenly divisible by filter depth: ",
+ in_depth, " vs ", patch_depth));
// The last dimension for filter is out_depth.
const int out_depth = static_cast<int>(filter.dim_size(3));
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
VLOG(2) << "Conv2D: in_depth = " << in_depth
+ << ", patch_depth = " << patch_depth
<< ", input_cols = " << input_cols
<< ", filter_cols = " << filter_cols
<< ", input_rows = " << input_rows
#endif // USE_GEMM_FOR_CONV
// To be used inside depthwise_conv_op.cc.
+template struct LaunchConv2DOp<CPUDevice, Eigen::half>;
template struct LaunchConv2DOp<CPUDevice, float>;
+template struct LaunchConv2DOp<CPUDevice, double>;
#if GOOGLE_CUDA
int64 GetCudnnWorkspaceLimit(const string& envvar_in_mb,
}
Tensor input = input_param;
-
- if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_dilation == 1 &&
- col_dilation == 1 && row_stride == 1 && col_stride == 1 &&
- data_format == FORMAT_NHWC) {
+ const int64 in_batch = GetTensorDim(input, data_format, 'N');
+ int64 in_rows = GetTensorDim(input, data_format, 'H');
+ int64 in_cols = GetTensorDim(input, data_format, 'W');
+ const int64 in_depths = GetTensorDim(input, data_format, 'C');
+ const int64 patch_rows = filter.dim_size(0);
+ const int64 patch_cols = filter.dim_size(1);
+ const int64 patch_depths = filter.dim_size(2);
+
+ // If the filter in-depth (patch_depths) is 1 and smaller than the input
+ // depth, it's a depthwise convolution. More generally, if the filter in-depth
+ // divides but is smaller than the input depth, it is a grouped convolution.
+ bool is_grouped_convolution = patch_depths != in_depths;
+ if (patch_rows == 1 && patch_cols == 1 && !is_grouped_convolution &&
+ row_dilation == 1 && col_dilation == 1 && row_stride == 1 &&
+ col_stride == 1 && data_format == FORMAT_NHWC) {
// 1x1 filter, so call cublas directly.
- const uint64 m = input.dim_size(0) * input.dim_size(1) * input.dim_size(2);
- const uint64 k = filter.dim_size(2);
+ const uint64 m = in_batch * in_rows * in_cols;
+ const uint64 k = patch_depths;
const uint64 n = filter.dim_size(3);
auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
", n=", n, ", k=", k));
}
return;
- } else if (filter.dim_size(0) == input.dim_size(1) &&
- filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 &&
+ } else if (patch_rows == in_rows && patch_cols == in_cols &&
+ !is_grouped_convolution && row_dilation == 1 &&
col_dilation == 1 && padding == VALID &&
data_format == FORMAT_NHWC) {
// The input data and filter have the same height/width, so call cublas
// directly.
- const uint64 m = input.dim_size(0);
- const uint64 k =
- filter.dim_size(0) * filter.dim_size(1) * filter.dim_size(2);
+ const uint64 m = in_batch;
+ const uint64 k = patch_rows * patch_cols * patch_depths;
const uint64 n = filter.dim_size(3);
auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
int padding_rows = 0;
int padding_cols = 0;
- const int64 in_batch = GetTensorDim(input, data_format, 'N');
- int64 in_rows = GetTensorDim(input, data_format, 'H');
- int64 in_cols = GetTensorDim(input, data_format, 'W');
- const int64 in_depths = GetTensorDim(input, data_format, 'C');
const int64 out_batch = GetTensorDim(*output, data_format, 'N');
const int64 out_rows = GetTensorDim(*output, data_format, 'H');
const int64 out_cols = GetTensorDim(*output, data_format, 'W');
const int64 out_depths = GetTensorDim(*output, data_format, 'C');
- const int64 patch_rows = filter.dim_size(0);
- const int64 patch_cols = filter.dim_size(1);
if (padding == SAME) {
// Total padding on rows and cols is
// Pr = (R' - 1) * S + (Kr - 1) * Dr + 1 - R
.set_feature_map_count(out_depths)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::FilterDescriptor filter_desc;
- filter_desc.set_input_filter_height(filter.dim_size(0))
- .set_input_filter_width(filter.dim_size(1))
- .set_input_feature_map_count(filter.dim_size(2))
+ filter_desc.set_input_filter_height(patch_rows)
+ .set_input_filter_width(patch_cols)
+ .set_input_feature_map_count(patch_depths)
.set_output_feature_map_count(filter.dim_size(3));
se::dnn::ConvolutionDescriptor conv_desc;
conv_desc.set_vertical_dilation_rate(row_dilation)
.set_vertical_filter_stride(row_stride)
.set_horizontal_filter_stride(col_stride)
.set_zero_padding_height(padding_rows / 2)
- .set_zero_padding_width(padding_cols / 2);
+ .set_zero_padding_width(padding_cols / 2)
+ .set_group_count(in_depths / patch_depths);
Tensor transformed_filter;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(
in_cols}}, // in_cols
out_depths, // out_depths
{{patch_rows, // filter_rows
- patch_cols}}, // filter_cols
+ patch_cols, // filter_cols
+ patch_depths}}, // filter_depths
{{row_dilation, // dilation_rows
col_dilation}}, // dilation_cols
{{row_stride, // stride_rows
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
extern template struct PadInput<GPUDevice, T, int, 4>
-DECLARE_GPU_SPEC(double);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
+DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
} // namespace functor
Conv2DOp<GPUDevice, double>);
// To be used inside depthwise_conv_op.cc.
-template class LaunchConv2DOp<GPUDevice, float>;
+template struct LaunchConv2DOp<GPUDevice, float>;
+template struct LaunchConv2DOp<GPUDevice, Eigen::half>;
+template struct LaunchConv2DOp<GPUDevice, double>;
#endif // GOOGLE_CUDA
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/depthwise_conv_op.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
+#include "cuda/include/cudnn.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
}
}
+// Extern template instantiated in conv_grad_input_ops.cc.
+extern template struct LaunchConv2DBackpropInputOp<CPUDevice, Eigen::half>;
+extern template struct LaunchConv2DBackpropInputOp<CPUDevice, float>;
+extern template struct LaunchConv2DBackpropInputOp<CPUDevice, double>;
+
#if GOOGLE_CUDA
+// Extern template instantiated in conv_grad_input_ops.cc.
+extern template struct LaunchConv2DBackpropInputOp<GPUDevice, Eigen::half>;
+extern template struct LaunchConv2DBackpropInputOp<GPUDevice, float>;
+extern template struct LaunchConv2DBackpropInputOp<GPUDevice, double>;
+
+// Extern template instantiated in depthwise_conv_op_gpu.cu.cc.
extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice,
Eigen::half>;
extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, float>;
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+
+ // For in_depth == 1 and grouped convolutions.
+ use_cudnn_ = CanUseCudnn();
+ cudnn_use_autotune_ = CudnnUseAutotune();
+ use_cudnn_grouped_conv_ = false;
+ dtype_ = DataTypeToEnum<T>::value;
}
void Compute(OpKernelContext* context) override {
input_sizes.dims()));
TensorShape input_shape;
const int32* in_sizes_data = input_sizes.template flat<int32>().data();
+
for (int i = 0; i < input_sizes.NumElements(); ++i) {
OP_REQUIRES(context, in_sizes_data[i] >= 0,
errors::InvalidArgument("Dimension ", i,
}
const TensorShape& filter_shape = filter.shape();
EXTRACT_AND_VERIFY_DIMENSIONS("DepthwiseConv2DBackpropInput");
+
Tensor* in_backprop = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{0}, 0, input_shape, &in_backprop));
- auto out_backprop_ptr = out_backprop.template flat<T>().data();
- auto filter_ptr = filter.template flat<T>().data();
- auto in_backprop_ptr = in_backprop->template flat<T>().data();
+
// If there is nothing to compute, return.
if (input_shape.num_elements() == 0) {
return;
}
+
+ // If in_depth==1, this operation is just a standard convolution.
+ // Depthwise convolution is a special case of cuDNN's grouped convolution.
+ bool use_cudnn = use_cudnn_ && (in_depth == 1 || use_cudnn_grouped_conv_);
+
+ VLOG(2) << "DepthwiseConv2dNativeBackpropInput: "
+ << " Input: [" << batch << ", " << input_rows << ", " << input_cols
+ << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
+ << filter_cols << ", " << in_depth << ", " << depth_multiplier
+ << "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
+ << ", " << out_depth << "], stride = " << stride_
+ << ", pad_rows = " << pad_rows << ", pad_cols = " << pad_cols
+ << ", Use cuDNN: " << use_cudnn;
+
+ if (use_cudnn) {
+ // Reshape from TF depthwise filter to cuDNN grouped convolution filter:
+ //
+ // | TensorFlow | cuDNN
+ // --------------------------------------------------------------------
+ // filter_out_depth | depth_multiplier | depth_multiplier * group_count
+ // filter_in_depth | in_depth | in_depth / group_count
+ //
+ // For depthwise convolution, we have group_count == in_depth.
+ int32 filter_in_depth = 1;
+ TensorShape shape =
+ TensorShape{filter_rows, filter_cols, filter_in_depth, out_depth};
+ Tensor reshaped_filter(/*type=*/dtype_);
+ OP_REQUIRES(
+ context, reshaped_filter.CopyFrom(filter, shape),
+ errors::Internal(
+ "Failed to reshape filter tensor for grouped convolution."));
+ // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
+ // conv is supported.
+ launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop,
+ reshaped_filter, /*row_dilation=*/1, /*col_dilation=*/1,
+ stride_, stride_, padding_, in_backprop, data_format_);
+ return;
+ }
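A concrete instance of the reshape described in the table above (the sizes are illustrative, not from this change):

// Depthwise case with in_depth = 8, depth_multiplier = 2, 3x3 filter:
//   TensorFlow filter:    [3, 3, in_depth = 8, depth_multiplier = 2]
//   cuDNN grouped filter: [3, 3, filter_in_depth = 1, out_depth = 16]
// group_count = in_depth / filter_in_depth = 8, i.e. one group per input
// channel, and out_depth = depth_multiplier * group_count = 16.
TensorShape grouped_shape({3, 3, 1, 16});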
+
+ auto out_backprop_ptr = out_backprop.template flat<T>().data();
+ auto filter_ptr = filter.template flat<T>().data();
+ auto in_backprop_ptr = in_backprop->template flat<T>().data();
LaunchDepthwiseConvBackpropInputOp<Device, T>()(
context, args, out_backprop_ptr, filter_ptr, in_backprop_ptr,
data_format_);
}
+ protected:
+ bool use_cudnn_grouped_conv_;
+
private:
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_;
int64 stride_;
+ // For in_depth == 1 and grouped convolutions.
+ LaunchConv2DBackpropInputOp<Device, T> launcher_;
+ bool use_cudnn_;
+ bool cudnn_use_autotune_;
+ DataType dtype_;
+
TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeBackpropInputOp);
};
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T"), \
DepthwiseConv2dNativeBackpropInputOp<CPUDevice, T>);
+
+TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
+#if !defined(PLATFORM_WINDOWS) || !defined(_DEBUG)
TF_CALL_double(REGISTER_CPU_KERNEL);
+#endif
#undef REGISTER_CPU_KERNEL
#if GOOGLE_CUDA
-REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropInput")
- .Device(DEVICE_GPU)
- .TypeConstraint<float>("T")
- .HostMemory("input_sizes"),
- DepthwiseConv2dNativeBackpropInputOp<GPUDevice, float>);
-
-REGISTER_KERNEL_BUILDER(
- Name("DepthwiseConv2dNativeBackpropInput")
- .Device(DEVICE_GPU)
- .TypeConstraint<double>("T")
- .HostMemory("input_sizes"),
- DepthwiseConv2dNativeBackpropInputOp<GPUDevice, double>);
+
+#define REGISTER_GPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropInput") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("input_sizes"), \
+ DepthwiseConv2dNativeBackpropInputOp<GPUDevice, T>)
+
+TF_CALL_half(REGISTER_GPU_KERNEL);
+TF_CALL_float(REGISTER_GPU_KERNEL);
+TF_CALL_double(REGISTER_GPU_KERNEL);
+#undef REGISTER_GPU_KERNEL
+
+#if CUDNN_VERSION >= 7000
+template <typename T>
+class DepthwiseConv2dGroupedConvBackpropInputOp
+ : public DepthwiseConv2dNativeBackpropInputOp<GPUDevice, T> {
+ public:
+ DepthwiseConv2dGroupedConvBackpropInputOp(OpKernelConstruction* context)
+ : DepthwiseConv2dNativeBackpropInputOp<GPUDevice, T>(context) {
+ this->use_cudnn_grouped_conv_ = true;
+ }
+};
+
+#define REGISTER_GROUPED_CONV_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropInput") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("input_sizes") \
+ .Label("cudnn_grouped_convolution"), \
+ DepthwiseConv2dGroupedConvBackpropInputOp<T>)
+
+TF_CALL_half(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_float(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_double(REGISTER_GROUPED_CONV_KERNEL);
+#undef REGISTER_GROUPED_CONV_KERNEL
+#endif // CUDNN_VERSION
#endif // GOOGLE_CUDA
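Kernels registered with .Label(...) above are only chosen when the node carries a matching kernel label; a hedged sketch of how a graph node would opt into the grouped-convolution kernel (the _kernel attribute is TensorFlow's kernel-label attribute; the op and label strings come from the registration above, the function name is hypothetical):

#include "tensorflow/core/framework/node_def.pb.h"

tensorflow::NodeDef MakeGroupedDepthwiseBackpropInputNode() {
  tensorflow::NodeDef node;
  node.set_op("DepthwiseConv2dNativeBackpropInput");
  // Matches .Label("cudnn_grouped_convolution") in the registration above.
  (*node.mutable_attr())["_kernel"].set_s("cudnn_grouped_convolution");
  return node;
}

The Python tests further down reach the same kernels through graph._kernel_label_map, which sets this attribute on the ops it wraps.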
// Kernels to compute the gradients of the filters for depthwise convolution.
}
}
+// Extern template instantiated in conv_grad_filter_ops.cc.
+extern template struct LaunchConv2DBackpropFilterOp<CPUDevice, Eigen::half>;
+extern template struct LaunchConv2DBackpropFilterOp<CPUDevice, float>;
+extern template struct LaunchConv2DBackpropFilterOp<CPUDevice, double>;
+
#if GOOGLE_CUDA
+// Extern template instantiated in conv_grad_filter_ops.cc.
+extern template struct LaunchConv2DBackpropFilterOp<GPUDevice, Eigen::half>;
+extern template struct LaunchConv2DBackpropFilterOp<GPUDevice, float>;
+extern template struct LaunchConv2DBackpropFilterOp<GPUDevice, double>;
+
+// Extern template instantiated in depthwise_conv_op_gpu.cu.cc.
extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice,
Eigen::half>;
extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, float>;
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+
+ // For in_depth == 1 and grouped convolutions.
+ use_cudnn_ = CanUseCudnn();
+ cudnn_use_autotune_ = CudnnUseAutotune();
+ use_cudnn_grouped_conv_ = false;
+
+ if (std::is_same<T, Eigen::half>::value) {
+ dtype_ = DT_HALF;
+ } else if (std::is_same<T, float>::value) {
+ dtype_ = DT_FLOAT;
+ } else if (std::is_same<T, double>::value) {
+ dtype_ = DT_DOUBLE;
+ } else {
+ LOG(ERROR) << "Only half, float, and double are supported.";
+ }
}
void Compute(OpKernelContext* context) override {
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{1}, 0, filter_shape, &filter_backprop));
- auto out_backprop_ptr = out_backprop.template flat<T>().data();
- auto input_ptr = input.template flat<T>().data();
- auto filter_backprop_ptr = filter_backprop->template flat<T>().data();
// If there is nothing to compute, return.
if (filter_shape.num_elements() == 0) {
return;
}
+
+ // If in_depth==1, this operation is just a standard convolution.
+ // Depthwise convolution is a special case of cuDNN's grouped convolution.
+ bool use_cudnn = use_cudnn_ && (in_depth == 1 || use_cudnn_grouped_conv_);
+
+ VLOG(2) << "DepthwiseConv2dNativeBackpropFilter: "
+ << " Input: [" << batch << ", " << input_rows << ", " << input_cols
+ << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
+ << filter_cols << ", " << in_depth << ", " << depth_multiplier
+ << "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
+ << ", " << out_depth << "], stride = " << stride_
+ << ", pad_rows = " << pad_rows << ", pad_cols = " << pad_cols
+ << ", Use cuDNN: " << use_cudnn;
+
+ if (use_cudnn) {
+ // Reshape from TF depthwise filter to cuDNN grouped convolution filter:
+ //
+ // | TensorFlow | cuDNN
+ // --------------------------------------------------------------------
+ // filter_out_depth | depth_multiplier | depth_multiplier * group_count
+ // filter_in_depth | in_depth | in_depth / group_count
+ //
+ // For depthwise convolution, we have group_count == in_depth.
+ int32 filter_in_depth = 1;
+ TensorShape shape =
+ TensorShape{filter_rows, filter_cols, filter_in_depth, out_depth};
+ Tensor reshaped_filter(/*type=*/dtype_);
+ OP_REQUIRES(
+ context, reshaped_filter.CopyFrom(*filter_backprop, shape),
+ errors::Internal(
+ "Failed to reshape filter tensor for grouped convolution."));
+
+ // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
+ // conv is supported.
+ launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, input,
+ /*row_dilation=*/1, /*col_dilation=*/1, stride_, stride_,
+ padding_, &reshaped_filter, data_format_);
+ return;
+ }
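A note on the CopyFrom above: it shares the underlying buffer and only reinterprets the shape, so the gradient the launcher writes into reshaped_filter also lands in filter_backprop. A small, hedged illustration of that aliasing behaviour (shapes are illustrative only):

// CopyFrom aliases the buffer: both tensors see the same 12 floats.
Tensor original(DT_FLOAT, TensorShape({3, 4}));
Tensor alias;
CHECK(alias.CopyFrom(original, TensorShape({2, 6})));
alias.flat<float>()(0) = 42.0f;
CHECK_EQ(original.flat<float>()(0), 42.0f);  // Visible through 'original' too.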
+
+ auto out_backprop_ptr = out_backprop.template flat<T>().data();
+ auto input_ptr = input.template flat<T>().data();
+ auto filter_backprop_ptr = filter_backprop->template flat<T>().data();
LaunchDepthwiseConvBackpropFilterOp<Device, T>()(
context, args, out_backprop_ptr, input_ptr, filter_backprop_ptr,
data_format_);
}
+ protected:
+ bool use_cudnn_grouped_conv_;
+
private:
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_;
int64 stride_;
+ // For in_depth == 1 and grouped convolutions.
+ LaunchConv2DBackpropFilterOp<Device, T> launcher_;
+ bool use_cudnn_;
+ bool cudnn_use_autotune_;
+ DataType dtype_;
+
TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeBackpropFilterOp);
};
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T"), \
DepthwiseConv2dNativeBackpropFilterOp<CPUDevice, T>);
+TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
+#if !defined(PLATFORM_WINDOWS) || !defined(_DEBUG)
TF_CALL_double(REGISTER_CPU_KERNEL);
+#endif
#undef REGISTER_CPU_KERNEL
#if GOOGLE_CUDA
-REGISTER_KERNEL_BUILDER(
- Name("DepthwiseConv2dNativeBackpropFilter")
- .Device(DEVICE_GPU)
- .TypeConstraint<float>("T")
- .HostMemory("filter_sizes"),
- DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, float>);
-
-REGISTER_KERNEL_BUILDER(
- Name("DepthwiseConv2dNativeBackpropFilter")
- .Device(DEVICE_GPU)
- .TypeConstraint<double>("T")
- .HostMemory("filter_sizes"),
- DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, double>);
+#define REGISTER_GPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropFilter") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("filter_sizes"), \
+ DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, T>)
+
+TF_CALL_half(REGISTER_GPU_KERNEL);
+TF_CALL_float(REGISTER_GPU_KERNEL);
+TF_CALL_double(REGISTER_GPU_KERNEL);
+#undef REGISTER_GPU_KERNEL
+
+#if CUDNN_VERSION >= 7000
+template <typename T>
+class DepthwiseConv2dGroupedConvBackpropFilterOp
+ : public DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, T> {
+ public:
+ DepthwiseConv2dGroupedConvBackpropFilterOp(OpKernelConstruction* context)
+ : DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, T>(context) {
+ this->use_cudnn_grouped_conv_ = true;
+ }
+};
+
+#define REGISTER_GROUPED_CONV_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropFilter") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("filter_sizes") \
+ .Label("cudnn_grouped_convolution"), \
+ DepthwiseConv2dGroupedConvBackpropFilterOp<T>)
+
+TF_CALL_half(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_float(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_double(REGISTER_GROUPED_CONV_KERNEL);
+#undef REGISTER_GROUPED_CONV_KERNEL
+#endif // CUDNN_VERSION
#endif // GOOGLE_CUDA
} // namespace tensorflow
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
+#include "cuda/include/cudnn.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
};
// Extern template instantiated in conv_ops.cc.
+extern template struct LaunchConv2DOp<CPUDevice, Eigen::half>;
extern template struct LaunchConv2DOp<CPUDevice, float>;
+extern template struct LaunchConv2DOp<CPUDevice, double>;
#if GOOGLE_CUDA
+// Extern template instantiated in conv_ops.cc.
+extern template struct LaunchConv2DOp<GPUDevice, Eigen::half>;
+extern template struct LaunchConv2DOp<GPUDevice, float>;
+extern template struct LaunchConv2DOp<GPUDevice, double>;
+
// Extern template instantiated in depthwise_conv_op_gpu.cc.
extern template struct LaunchDepthwiseConvOp<GPUDevice, Eigen::half>;
extern template struct LaunchDepthwiseConvOp<GPUDevice, float>;
extern template struct LaunchDepthwiseConvOp<GPUDevice, double>;
-// Extern template instantiated in conv_ops.cc.
-extern template struct LaunchConv2DOp<GPUDevice, float>;
-
#endif
template <typename Device, typename T>
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
- // For special case when in_depth == 1.
+ // For in_depth == 1 and grouped convolutions.
use_cudnn_ = CanUseCudnn();
cudnn_use_autotune_ = CudnnUseAutotune();
+ use_cudnn_grouped_conv_ = false;
+ dtype_ = DataTypeToEnum<T>::value;
}
void Compute(OpKernelContext* context) override {
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
- VLOG(2) << "DepthwiseConv2dNative: "
- << " Input: [" << batch << ", " << input_rows << ", " << input_cols
- << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
- << filter_cols << ", " << in_depth << ", " << depth_multiplier
- << "]; stride = " << stride_ << ", pad_rows = " << pad_rows
- << ", pad_cols = " << pad_cols << ", output: [" << batch << ", "
- << out_rows << ", " << out_cols << ", " << out_depth << "]";
-
// If there is nothing to compute, return.
if (out_shape.num_elements() == 0) {
return;
}
- // If in_depth==1, this operation is just a standard convolution, so
- // invoke that op.
- if (std::is_same<T, float>::value && in_depth == 1) {
+ // TODO(csigg): Have autotune decide if native is faster than cuDNN.
+ // If in_depth==1, this operation is just a standard convolution.
+ // Depthwise convolution is a special case of cuDNN's grouped convolution.
+ bool use_cudnn = use_cudnn_ && (in_depth == 1 || use_cudnn_grouped_conv_);
+
+ VLOG(2) << "DepthwiseConv2dNative: "
+ << " Input: [" << batch << ", " << input_rows << ", " << input_cols
+ << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
+ << filter_cols << ", " << in_depth << ", " << depth_multiplier
+ << "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
+ << ", " << out_depth << "], stride = " << stride_
+ << ", pad_rows = " << pad_rows << ", pad_cols = " << pad_cols
+ << ", Use cuDNN: " << use_cudnn;
+
+ if (use_cudnn) {
+ // Reshape from TF depthwise filter to cuDNN grouped convolution filter:
+ //
+ // | TensorFlow | cuDNN
+ // --------------------------------------------------------------------
+ // filter_out_depth | depth_multiplier | depth_multiplier * group_count
+ // filter_in_depth | in_depth | in_depth / group_count
+ //
+ // For depthwise convolution, we have group_count == in_depth.
+ int32 filter_in_depth = 1;
+ TensorShape shape =
+ TensorShape{filter_rows, filter_cols, filter_in_depth, out_depth};
+ Tensor reshaped_filter(/*type=*/dtype_);
+ OP_REQUIRES(
+ context, reshaped_filter.CopyFrom(filter, shape),
+ errors::Internal(
+ "Failed to reshape filter tensor for grouped convolution."));
// TODO(yangzihao): Send in arbitrary dilation rates after the dilated
// conv is supported.
- launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
- /*row_dilation=*/1, /*col_dilation=*/1, stride_, stride_,
- padding_, output, data_format_);
+ launcher_(context, use_cudnn_, cudnn_use_autotune_, input,
+ reshaped_filter, /*row_dilation=*/1, /*col_dilation=*/1,
+ stride_, stride_, padding_, output, data_format_);
return;
}
output_ptr, data_format_);
}
+ protected:
+ bool use_cudnn_grouped_conv_;
+
private:
std::vector<int32> strides_;
Padding padding_;
int64 stride_; // in height/width dimension.
- // For the case in_depth == 1.
+ // For in_depth == 1 and grouped convolutions.
LaunchConv2DOp<Device, T> launcher_;
bool use_cudnn_;
bool cudnn_use_autotune_;
+ DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeOp);
};
#define REGISTER_CPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("DepthwiseConv2dNative").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- DepthwiseConv2dNativeOp<CPUDevice, T>);
+ DepthwiseConv2dNativeOp<CPUDevice, T>)
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
#endif
#if GOOGLE_CUDA
-REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")
- .Device(DEVICE_GPU)
- .TypeConstraint<Eigen::half>("T"),
- DepthwiseConv2dNativeOp<GPUDevice, Eigen::half>);
-
-REGISTER_KERNEL_BUILDER(
- Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint<float>("T"),
- DepthwiseConv2dNativeOp<GPUDevice, float>);
-
-REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")
- .Device(DEVICE_GPU)
- .TypeConstraint<double>("T"),
- DepthwiseConv2dNativeOp<GPUDevice, double>);
-#endif
+
+#define REGISTER_GPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+ DepthwiseConv2dNativeOp<GPUDevice, T>)
+
+TF_CALL_half(REGISTER_GPU_KERNEL);
+TF_CALL_float(REGISTER_GPU_KERNEL);
+TF_CALL_double(REGISTER_GPU_KERNEL);
+
+#if CUDNN_VERSION >= 7000
+template <typename T>
+class DepthwiseConv2dGroupedConvOp
+ : public DepthwiseConv2dNativeOp<GPUDevice, T> {
+ public:
+ DepthwiseConv2dGroupedConvOp(OpKernelConstruction* context)
+ : DepthwiseConv2dNativeOp<GPUDevice, T>(context) {
+ this->use_cudnn_grouped_conv_ = true;
+ }
+};
+
+#define REGISTER_GROUPED_CONV_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .Label("cudnn_grouped_convolution"), \
+ DepthwiseConv2dGroupedConvOp<T>)
+
+TF_CALL_half(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_float(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_double(REGISTER_GROUPED_CONV_KERNEL);
+#endif // CUDNN_VERSION
+#endif // GOOGLE_CUDA
} // namespace tensorflow
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
def ConfigsToTest():
padding,
data_type,
use_gpu,
+ grouped_conv=False,
data_format="NHWC"):
"""Verifies the output values of the convolution function.
padding: Padding type.
data_type: The data type to use.
use_gpu: Whether to use GPU.
+ grouped_conv: Whether to use cuDNN 7's grouped convolution.
data_format: The data_format of the input. "NHWC" or "NCHW".
"""
- total_size_1 = 1
- total_size_2 = 1
+ input_size = 1
+ filter_size = 1
for s in tensor_in_sizes:
- total_size_1 *= s
+ input_size *= s
for s in filter_in_sizes:
- total_size_2 *= s
+ filter_size *= s
# Initializes the input and filter tensor with numbers incrementing from 1.
- x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
- x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
- with self.test_session(use_gpu=use_gpu) as sess:
- if data_type == dtypes.float16:
- tolerance = 1e-5
- elif data_type == dtypes.float32:
- tolerance = 1e-5
- else:
- self.assertEqual(data_type, dtypes.float64)
- tolerance = 1e-8
+ x1 = [f * 1.0 / input_size for f in range(1, input_size + 1)]
+ x2 = [f * 1.0 / filter_size for f in range(1, filter_size + 1)]
+ ops.reset_default_graph()
+ graph = ops.get_default_graph()
+ with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
+ tolerance = {
+ dtypes.float16: 4e-2,
+ dtypes.float32: 1e-8,
+ dtypes.float64: 1e-13,
+ }[data_type]
t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=data_type)
t1.set_shape(tensor_in_sizes)
native_t1 = array_ops.transpose(t1, [0, 3, 1, 2])
strides = [1, 1, stride, stride]
- conv_native = nn_ops.depthwise_conv2d_native(
- native_t1,
- t2,
- strides=strides,
- data_format=data_format,
- padding=padding)
+ with sess.graph._kernel_label_map({
+ "DepthwiseConv2dNative": "cudnn_grouped_convolution"
+ } if grouped_conv else {}):
+ conv_native = nn_ops.depthwise_conv2d_native(
+ native_t1,
+ t2,
+ strides=strides,
+ data_format=data_format,
+ padding=padding)
if data_format == "NCHW":
# Transpose back from NCHW to NHWC
conv_native = array_ops.transpose(conv_native, [0, 2, 3, 1])
+ try:
+ native_result = sess.run(conv_native)
+ except errors.InvalidArgumentError as e:
+ # Grouped convolution kernel is only registered for cuDNN 7. Silently
+ # return when we are running on an earlier version or without GPU.
+ if e.message.startswith(
+ "No OpKernel was registered to support Op 'DepthwiseConv2dNative'"):
+ tf_logging.warn("Skipping grouped convolution test")
+ return
+ raise e
+
conv_interface = nn_impl.depthwise_conv2d(
t1, t2, strides=[1, stride, stride, 1], padding=padding)
-
- native_result = sess.run(conv_native)
interface_result = sess.run(conv_interface)
- print("data_type:", data_type, "use_gpu:", use_gpu, "max diff = ",
- np.amax(np.absolute(native_result - interface_result)))
+ tf_logging.info(
+ "data_type: %r, use_gpu: %r, grouped_conv: %r, max diff = %f",
+ data_type, use_gpu, grouped_conv,
+ np.amax(np.absolute(native_result - interface_result)))
self.assertArrayNear(
np.ravel(native_result), np.ravel(interface_result), tolerance)
self.assertShapeEqual(native_result, conv_native)
def testDepthwiseConv2D(self):
for index, (input_size, filter_size, _, stride,
padding) in enumerate(ConfigsToTest()):
- print("Testing DepthwiseConv2D,", index, "th config:", input_size, "*",
- filter_size, "stride:", stride, "padding:", padding)
+ tf_logging.info(
+ "Testing DepthwiseConv2D, %dth config: %r * %r, stride: %d, padding: "
+ "%s", index, input_size, filter_size, stride, padding)
for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ tf_logging.info("Testing without grouped_conv")
self._VerifyValues(
input_size, filter_size, stride, padding, data_type, use_gpu=True)
+ tf_logging.info("Testing with grouped_conv")
+ self._VerifyValues(
+ input_size,
+ filter_size,
+ stride,
+ padding,
+ data_type,
+ use_gpu=True,
+ grouped_conv=True)
def testDepthwiseConv2DFormat(self):
if not test.is_gpu_available():
for index, (input_size, filter_size, _, stride,
padding) in enumerate(ConfigsToTest()):
- print("Testing DepthwiseConv2DFormat,", index, "th config:", input_size,
- "*", filter_size, "stride:", stride, "padding:", padding)
+ tf_logging.info(
+ "Testing DepthwiseConv2DFormat, %dth config: %r * %r, stride: %d, "
+ "padding: %s", index, input_size, filter_size, stride, padding)
for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
self._VerifyValues(
input_size,
conv = nn_ops.depthwise_conv2d_native(
t1, t2, strides=[1, stride, stride, 1], padding=padding)
value = sess.run(conv)
- print("value = ", value)
+ tf_logging.info("value = %r", value)
self.assertArrayNear(expected, np.ravel(value), 1e-5)
self.assertShapeEqual(value, conv)
expected=expected_output,
use_gpu=True)
- # Gradient checkers.This tests depthwise gradient computations for both
+ # Gradient checkers. This tests depthwise gradient computations for both
# BackpropFilter and BackpropInput by comparing gradients computed by the
# depthwise gradient ops with the gradients computed numerically (details can
# be found in compute_gradient_error()).
data_type,
test_input,
use_gpu,
+ grouped_conv=False,
data_format="NHWC"):
input_size = 1
for x in input_shape:
filter_size *= x
input_data = [x * 1.0 / input_size for x in range(0, input_size)]
filter_data = [x * 1.0 / filter_size for x in range(0, filter_size)]
- with self.test_session(use_gpu=use_gpu):
- if data_type == dtypes.float16:
- tolerance = 0.002
- elif data_type == dtypes.float32:
- tolerance = 0.002
- else:
- self.assertEqual(data_type, dtypes.float64)
- tolerance = 1e-8
+ ops.reset_default_graph()
+ graph = ops.get_default_graph()
+ with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
+ tolerance = {
+ dtypes.float16: 2e-0,
+ dtypes.float32: 5e-4,
+ dtypes.float64: 1e-12,
+ }[data_type]
input_tensor = constant_op.constant(
input_data, shape=input_shape, dtype=data_type, name="input")
]
strides = [1, 1, stride, stride]
- depthwise_conv2d = nn_ops.depthwise_conv2d_native(
- native_input,
- filter_tensor,
- strides,
- padding,
- data_format=data_format,
- name="depthwise_conv2d")
+ with sess.graph._kernel_label_map({
+ "DepthwiseConv2dNative": "cudnn_grouped_convolution",
+ "DepthwiseConv2dNativeBackpropInput": "cudnn_grouped_convolution",
+ "DepthwiseConv2dNativeBackpropFilter": "cudnn_grouped_convolution",
+ } if grouped_conv else {}):
+ depthwise_conv2d = nn_ops.depthwise_conv2d_native(
+ native_input,
+ filter_tensor,
+ strides,
+ padding,
+ data_format=data_format,
+ name="depthwise_conv2d")
self.assertEqual(output_shape, depthwise_conv2d.get_shape())
- if test_input:
- err = gradient_checker.compute_gradient_error(
- native_input, input_shape, depthwise_conv2d, output_shape)
- else:
- err = gradient_checker.compute_gradient_error(filter_tensor,
- filter_shape,
- depthwise_conv2d,
- output_shape)
- print("data_type:", data_type, "use_gpu:", use_gpu, ", error = ", err)
+
+ try:
+ if test_input:
+ err = gradient_checker.compute_gradient_error(
+ native_input, input_shape, depthwise_conv2d, output_shape)
+ else:
+ err = gradient_checker.compute_gradient_error(
+ filter_tensor, filter_shape, depthwise_conv2d, output_shape)
+ except errors.InvalidArgumentError as e:
+ # Grouped convolution kernel is only registered for cuDNN 7. Silently
+ # return when we are running on an earlier version or without GPU.
+ if grouped_conv and e.message.startswith(
+ "No OpKernel was registered to support Op 'DepthwiseConv2dNative'"):
+ tf_logging.warn("Skipping grouped convolution test")
+ return
+ raise e
+
+ tf_logging.info(
+ "data_type: %r, use_gpu: %r, grouped_conv: %r, error = %f", data_type,
+ use_gpu, grouped_conv, err)
self.assertLess(err, tolerance)
def testDepthwiseConv2DInputGrad(self):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(CheckGradConfigsToTest()):
- print("Testing DepthwiseConv2DInputGrad,", index, "th config:",
- input_size, "*", filter_size, "stride:", stride, "padding:",
- padding)
- # Note: float16 test for DepthwiseConv2DInputGrad is not enabled,
- # calculations are not very precise.
- for data_type in [dtypes.float32, dtypes.float64]:
+ tf_logging.info(
+ "Testing DepthwiseConv2DInputGrad, %dth config: %r * %r, stride: %d, "
+ "padding: %s", index, input_size, filter_size, stride, padding)
+ for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
data_type,
test_input=True,
use_gpu=True)
+ self._ConstructAndTestGradient(
+ input_size,
+ filter_size,
+ output_size,
+ stride,
+ padding,
+ data_type,
+ test_input=True,
+ use_gpu=True,
+ grouped_conv=True)
def testDepthwiseConv2DInputGradFormat(self):
if not test.is_gpu_available():
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(CheckGradConfigsToTest()):
- print("Testing DepthwiseConv2DInputGradFormat,", index, "th config:",
- input_size, "*", filter_size, "stride:", stride, "padding:",
- padding)
- # Note: float16 test for DepthwiseConv2DInputGradFormat is not enabled,
- # calculations are not very precise.
- for data_type in [dtypes.float32, dtypes.float64]:
+ tf_logging.info(
+ "Testing DepthwiseConv2DInputGradFormat, %dth config: %r * %r, "
+ "stride: %d, padding: %s", index, input_size, filter_size, stride,
+ padding)
+ for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
def testDepthwiseConv2DFilterGrad(self):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(CheckGradConfigsToTest()):
- print("Testing DepthwiseConv2DFilterGrad,", index, "th config:",
- input_size, "*", filter_size, "stride:", stride, "padding:",
- padding)
- # Note: float16 test for DepthwiseConv2DFilterGrad is not enabled,
- # calculations are not very precise.
- for data_type in [dtypes.float32, dtypes.float64]:
+ tf_logging.info(
+ "Testing DepthwiseConv2DFilterGrad, %dth config: %r * %r, stride: "
+ "%d, padding: %s", index, input_size, filter_size, stride, padding)
+ for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(CheckGradConfigsToTest()):
- print("Testing DepthwiseConv2DFilterGradFormat,", index, "th config:",
- input_size, "*", filter_size, "stride:", stride, "padding:",
- padding)
- # Note: float16 test for DepthwiseConv2DFilterGradFormat is not enabled,
- # calculations are not very precise.
- for data_type in [dtypes.float32, dtypes.float64]:
+ tf_logging.info(
+ "Testing DepthwiseConv2DFilterGradFormat, %dth config: %r * %r, "
+ "stride: %d, padding: %s", index, input_size, filter_size, stride,
+ padding)
+ for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
def testDepthwiseConv2DInputGradCompare(self):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(ConfigsToTest()):
- print("Testing DepthwiseConv2DInputGradCompare,", index, "th config:",
- input_size, "*", filter_size, "stride:", stride, "padding:",
- padding)
+ tf_logging.info(
+ "Testing DepthwiseConv2DInputGradCompare, %dth config: %r * %r, "
+ "stride: %d, padding: %s", index, input_size, filter_size, stride,
+ padding)
self._CompareBackpropInputFloat(input_size, filter_size, output_size,
stride, padding)
self._CompareBackpropInputDouble(input_size, filter_size, output_size,
def testDepthwiseConv2DFilterGradCompare(self):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(ConfigsToTest()):
- print("Testing DepthwiseConv2DFilterGradCompare,", index, "th config:",
- input_size, "*", filter_size, "stride:", stride, "padding:",
- padding)
+ tf_logging.info(
+ "Testing DepthwiseConv2DFilterGradCompare, %dth config: %r * %r, "
+ "stride: %d, padding: %s", index, input_size, filter_size, stride,
+ padding)
self._CompareBackpropFilterFloat(input_size, filter_size, output_size,
stride, padding)
self._CompareBackpropFilterDouble(input_size, filter_size, output_size,
#if CUDNN_VERSION >= 7000
#define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \
__macro(cudnnSetConvolutionMathType) \
- __macro(cudnnSetRNNMatrixMathType)
+ __macro(cudnnSetRNNMatrixMathType) \
+ __macro(cudnnSetConvolutionGroupCount) \
+ __macro(cudnnGetConvolutionGroupCount)
// clang-format on
CUDNN_DNN_ROUTINE_EACH_R7(STREAM_EXECUTOR_CUDNN_WRAP)
// NOTE(benbarsdell): This only applies if tensor op math is enabled
// and algo selection is set to Default.
this->set_use_tensor_op_math(true);
+
+#if CUDNN_MAJOR >= 7
+ VLOG(2) << "Requesting grouped convolution: "
+ << convolution_descriptor.group_count();
+ status = wrap::cudnnSetConvolutionGroupCount(
+ parent_, handle_, convolution_descriptor.group_count());
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(FATAL) << "could not set cudnn convolution group count: "
+ << ToString(status);
+ }
+#else
+ CHECK_EQ(convolution_descriptor.group_count(), 1)
+ << "Requested grouped convolution for cuDNN version < 7";
+#endif
}
void set_use_tensor_op_math(bool use_tensor_op_math) {
filter_strides_(ndims, 1),
dilation_rates_(ndims, 1),
pad_alignment_(PadAlignment::kDefault),
+ group_count_(1),
ndims_(ndims) {}
ConvolutionDescriptor::ConvolutionDescriptor()
pad_alignment_ = pad_alignment;
return *this;
}
+ ConvolutionDescriptor& set_group_count(int group_count) {
+ group_count_ = group_count;
+ return *this;
+ }
int64 zero_padding_height() const {
return GetDim(zero_padding_, DimIndex::Y);
}
int filter_stride(DimIndex dim) const { return GetDim(filter_strides_, dim); }
int dilation_rate(DimIndex dim) const { return GetDim(dilation_rates_, dim); }
PadAlignment pad_alignment() const { return pad_alignment_; }
+ int group_count() const { return group_count_; }
int ndims() const { return ndims_; }
std::vector<int64> strides() const { return filter_strides_; }
std::vector<int64> filter_strides_;
std::vector<int64> dilation_rates_;
PadAlignment pad_alignment_;
+ int group_count_;
int ndims_;
// TODO(leary) cudnn provides these fields, but need to characterize what
// their effect is -- they may be boolean rather than integral.
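As a usage sketch for the new setter (values are illustrative; the other setters already exist on ConvolutionDescriptor and are used by the kernel changes above):

se::dnn::ConvolutionDescriptor conv_desc;
conv_desc.set_vertical_filter_stride(1)
    .set_horizontal_filter_stride(1)
    .set_vertical_dilation_rate(1)
    .set_horizontal_dilation_rate(1)
    .set_zero_padding_height(1)
    .set_zero_padding_width(1)
    // E.g. in_depth = 8 with filter_in_depth = 1 (depthwise) gives 8 groups.
    .set_group_count(8);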