Don't ever use cuDNN to perform depthwise convolutions on CPU.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 15 May 2018 20:38:34 +0000 (13:38 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 15 May 2018 20:40:56 +0000 (13:40 -0700)
PiperOrigin-RevId: 196721302

tensorflow/core/kernels/depthwise_conv_grad_op.cc
tensorflow/core/kernels/depthwise_conv_op.cc
tensorflow/python/kernel_tests/depthwise_conv_op_test.py

index 42a4832..da3bdb4 100644 (file)
@@ -564,7 +564,7 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
 
     // For in_depth == 1 and grouped convolutions.
-    use_cudnn_ = CanUseCudnn();
+    use_cudnn_ = CanUseCudnn() && std::is_same<Device, GPUDevice>::value;
     cudnn_use_autotune_ = CudnnUseAutotune();
     use_cudnn_grouped_conv_ = false;
     dtype_ = DataTypeToEnum<T>::value;
@@ -1037,7 +1037,7 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
 
     // For in_depth == 1 and grouped convolutions.
-    use_cudnn_ = CanUseCudnn();
+    use_cudnn_ = CanUseCudnn() && std::is_same<Device, GPUDevice>::value;
     cudnn_use_autotune_ = CudnnUseAutotune();
     use_cudnn_grouped_conv_ = false;
 
index d5f4a68..f0902fd 100644 (file)
@@ -290,7 +290,7 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
 
     // For in_depth == 1 and grouped convolutions.
-    use_cudnn_ = CanUseCudnn();
+    use_cudnn_ = CanUseCudnn() && std::is_same<Device, GPUDevice>::value;
     cudnn_use_autotune_ = CudnnUseAutotune();
     use_cudnn_grouped_conv_ = false;
     dtype_ = DataTypeToEnum<T>::value;
index 659dc04..5e223b1 100644 (file)
@@ -355,7 +355,7 @@ class DepthwiseConv2DTest(test.TestCase):
     graph = ops.get_default_graph()
     with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
       tolerance = {
-          dtypes.float16: 2e-0,
+          dtypes.float16: 4e-0,
           dtypes.float32: 5e-4,
           dtypes.float64: 1e-12,
       }[data_type]