Activates Eigen path for CPU implementation of atrous/dilated convolution (only forwa...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 17 Feb 2018 01:56:36 +0000 (17:56 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 17 Feb 2018 02:03:23 +0000 (18:03 -0800)
PiperOrigin-RevId: 186071285

tensorflow/core/kernels/conv_2d.h
tensorflow/core/kernels/conv_grad_filter_ops.cc
tensorflow/core/kernels/conv_grad_input_ops.cc
tensorflow/core/kernels/conv_ops.cc
tensorflow/python/kernel_tests/conv_ops_test.py

index 2142207b0d89a4b2f02c7f7b5d320c3b4b48462c..6949e5b5fd85f399473095f26314e9d58fa65464 100644 (file)
@@ -54,10 +54,12 @@ struct InflatePadAndShuffle {
 template <typename Device, typename Input, typename Filter, typename Output>
 void SpatialConvolutionFunc(const Device& d, Output output, Input input,
                             Filter filter, int row_stride, int col_stride,
+                            int row_dilation, int col_dilation,
                             const Eigen::PaddingType& padding) {
   // Need to swap row/col when calling Eigen.
   output.device(d) =
-      Eigen::SpatialConvolution(input, filter, col_stride, row_stride, padding);
+      Eigen::SpatialConvolution(input, filter, col_stride, row_stride, padding,
+                                col_dilation, row_dilation);
 }
 
 template <typename Device, typename T>
@@ -65,9 +67,10 @@ struct SpatialConvolution {
   void operator()(const Device& d, typename TTypes<T, 4>::Tensor output,
                   typename TTypes<T, 4>::ConstTensor input,
                   typename TTypes<T, 4>::ConstTensor filter, int row_stride,
-                  int col_stride, const Eigen::PaddingType& padding) {
+                  int col_stride, int row_dilation, int col_dilation,
+                  const Eigen::PaddingType& padding) {
     SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride,
-                           padding);
+                           row_dilation, col_dilation, padding);
   }
 };
 
@@ -77,11 +80,12 @@ struct SpatialConvolution<Device, Eigen::half> {
                   typename TTypes<Eigen::half, 4>::Tensor output,
                   typename TTypes<Eigen::half, 4>::ConstTensor input,
                   typename TTypes<Eigen::half, 4>::ConstTensor filter,
-                  int row_stride, int col_stride,
-                  const Eigen::PaddingType& padding) {
+                  int row_stride, int col_stride, int row_dilation,
+                  int col_dilation, const Eigen::PaddingType& padding) {
     output.device(d) =
         Eigen::SpatialConvolution(input.cast<float>(), filter.cast<float>(),
-                                  col_stride, row_stride, padding)
+                                  col_stride, row_stride, padding, col_dilation,
+                                  row_dilation)
             .cast<Eigen::half>();
   }
 };
@@ -91,11 +95,13 @@ struct SpatialConvolutionBackwardInput {
   void operator()(const Device& d, typename TTypes<T, 4>::Tensor input_backward,
                   typename TTypes<T, 4>::ConstTensor kernel,
                   typename TTypes<T, 4>::ConstTensor output_backward,
-                  int row_stride, int col_stride) {
+                  int row_stride, int col_stride, int row_dilation,
+                  int col_dilation) {
     // Need to swap row/col when calling Eigen.
     input_backward.device(d) = Eigen::SpatialConvolutionBackwardInput(
         kernel, output_backward, input_backward.dimension(2),
-        input_backward.dimension(1), col_stride, row_stride);
+        input_backward.dimension(1), col_stride, row_stride, col_dilation,
+        row_dilation);
   }
 };
 
@@ -105,11 +111,13 @@ struct SpatialConvolutionBackwardFilter {
                   typename TTypes<T, 4>::Tensor kernel_backward,
                   typename TTypes<T, 4>::ConstTensor input,
                   typename TTypes<T, 4>::ConstTensor output_backward,
-                  int row_stride, int col_stride) {
+                  int row_stride, int col_stride, int row_dilation,
+                  int col_dilation) {
     // Need to swap row/col when calling Eigen.
     kernel_backward.device(d) = Eigen::SpatialConvolutionBackwardKernel(
         input, output_backward, kernel_backward.dimension(1),
-        kernel_backward.dimension(0), col_stride, row_stride);
+        kernel_backward.dimension(0), col_stride, row_stride, col_dilation,
+        row_dilation);
   }
 };
 
index 512bcc6c01bf3eb4aed92f90eebb060abda8a7fc..b8a5ae6a08e5c22fb5d69112b216b3c342b1bb1a 100644 (file)
@@ -101,7 +101,8 @@ struct LaunchConv2DBackpropFilterOp<CPUDevice, T> {
     const CPUDevice& d = ctx->eigen_device<CPUDevice>();
     functor::SpatialConvolutionBackwardFilter<CPUDevice, T>()(
         d, filter_backprop->tensor<T, 4>(), input.tensor<T, 4>(),
-        out_backprop.tensor<T, 4>(), row_stride, col_stride);
+        out_backprop.tensor<T, 4>(), row_stride, col_stride,
+        /*row_dilation=*/1, /*col_dilation=*/1);
   }
 };
 
index 0356ff4c0f4240ec806d1e337546cfce6771d92f..b87c7899c00ab79c60bdbd85ce28399d103d271d 100644 (file)
@@ -106,7 +106,8 @@ struct LaunchConv2DBackpropInputOp<CPUDevice, T> {
     const CPUDevice& d = ctx->eigen_device<CPUDevice>();
     functor::SpatialConvolutionBackwardInput<CPUDevice, T>()(
         d, in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
-        out_backprop.tensor<T, 4>(), row_stride, col_stride);
+        out_backprop.tensor<T, 4>(), row_stride, col_stride,
+        /*row_dilation=*/1, /*col_dilation=*/1);
   }
 };
 
index dbddaf3dc640dcf2cad8f6ba7dd00aaa33a30e0c..2b81e14f95b1b3f4c04e02d50180a9adda9e51e0 100644 (file)
@@ -60,8 +60,8 @@ template <typename Device, typename T>
 struct LaunchGeneric {
   void operator()(OpKernelContext* ctx, const Tensor& input,
                   const Tensor& filter, int row_stride, int col_stride,
-                  const Padding& padding, Tensor* output,
-                  TensorFormat data_format) {
+                  int row_dilation, int col_dilation, const Padding& padding,
+                  Tensor* output, TensorFormat data_format) {
     CHECK(data_format == FORMAT_NHWC) << "Generic conv implementation only "
                                          "supports NHWC tensor format for now.";
     if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&
@@ -86,7 +86,8 @@ struct LaunchGeneric {
           filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
           dim_pair);
     } else if (filter.dim_size(0) == input.dim_size(1) &&
-               filter.dim_size(1) == input.dim_size(2) && padding == VALID) {
+               filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 &&
+               col_dilation == 1 && padding == VALID) {
       // If the input data and filter have the same height/width,
       // the 2D convolution is reduced to matrix multiplication.
       const int k =  // Length of reduction dimension.
@@ -103,7 +104,7 @@ struct LaunchGeneric {
       functor::SpatialConvolution<Device, T>()(
           ctx->eigen_device<Device>(), output->tensor<T, 4>(),
           input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
-          BrainPadding2EigenPadding(padding));
+          row_dilation, col_dilation, BrainPadding2EigenPadding(padding));
     }
   }
 };
@@ -122,15 +123,9 @@ struct LaunchConv2DOp<CPUDevice, T> {
                                 "NHWC tensor format for now."));
       return;
     }
-    // TODO(yangzihao): Add the CPU implementation of dilated conv 2D.
-    if (row_dilation > 1 || col_dilation > 1) {
-      ctx->SetStatus(
-          errors::Unimplemented("Generic conv implementation only supports "
-                                "dilated rate of 1 for now."));
-      return;
-    }
     LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
-                                  padding, output, data_format);
+                                  row_dilation, col_dilation, padding, output,
+                                  data_format);
   }
 };
 
@@ -792,7 +787,8 @@ namespace functor {
       const GPUDevice& d, typename TTypes<T, 4>::Tensor output,              \
       typename TTypes<T, 4>::ConstTensor input,                              \
       typename TTypes<T, 4>::ConstTensor filter, int row_stride,             \
-      int col_stride, const Eigen::PaddingType& padding);                    \
+      int col_stride, int row_dilation, int col_dilation,                    \
+      const Eigen::PaddingType& padding);                                    \
   extern template struct SpatialConvolution<GPUDevice, T>;                   \
   template <>                                                                \
   void MatMulConvFunctor<GPUDevice, T>::operator()(                          \
index edfb20d6a2b80cec930ddf696e8f0f69623a4de7..27857989167ecd11c33286d7bb6cb068edd12831 100644 (file)
@@ -302,25 +302,20 @@ class Conv2DTest(test.TestCase):
                                padding, dilations):
     expected_results = []
     computed_results = []
-    default_dilations = (dilations[0] == 1 and dilations[1] == 1)
     for data_format, use_gpu in GetTestConfigs():
-      # If any dilation rate is larger than 1, only do test on the GPU
-      # because we currently do not have a CPU implementation for arbitrary
-      # dilation rates.
-      if default_dilations or use_gpu:
-        expected, computed = self._ComputeReferenceDilatedConv(
-            tensor_in_sizes, filter_in_sizes, strides, dilations, padding,
-            data_format, use_gpu)
-        expected_results.append(expected)
-        computed_results.append(computed)
-        tolerance = 1e-2 if use_gpu else 1e-5
-        expected_values = self.evaluate(expected_results)
-        computed_values = self.evaluate(computed_results)
-        for e_value, c_value in zip(expected_values, computed_values):
-          print("expected = ", e_value)
-          print("actual = ", c_value)
-          self.assertAllClose(
-              e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-4)
+      expected, computed = self._ComputeReferenceDilatedConv(
+          tensor_in_sizes, filter_in_sizes, strides, dilations, padding,
+          data_format, use_gpu)
+      expected_results.append(expected)
+      computed_results.append(computed)
+      tolerance = 1e-2 if use_gpu else 1e-5
+      expected_values = self.evaluate(expected_results)
+      computed_values = self.evaluate(computed_results)
+      for e_value, c_value in zip(expected_values, computed_values):
+        print("expected = ", e_value)
+        print("actual = ", c_value)
+        self.assertAllClose(
+            e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-4)
 
   def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, strides, padding,
                     expected):
@@ -365,13 +360,12 @@ class Conv2DTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Filter2x1Dilation(self):
-    if test.is_gpu_available(cuda_only=True):
-      self._VerifyDilatedConvValues(
-          tensor_in_sizes=[1, 4, 4, 1],
-          filter_in_sizes=[2, 2, 1, 1],
-          strides=[1, 1],
-          dilations=[2, 1],
-          padding="VALID")
+    self._VerifyDilatedConvValues(
+        tensor_in_sizes=[1, 4, 4, 1],
+        filter_in_sizes=[2, 2, 1, 1],
+        strides=[1, 1],
+        dilations=[2, 1],
+        padding="VALID")
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2DEmpty(self):
@@ -385,13 +379,12 @@ class Conv2DTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2DEmptyDilation(self):
-    if test.is_gpu_available(cuda_only=True):
-      self._VerifyDilatedConvValues(
-          tensor_in_sizes=[0, 2, 3, 3],
-          filter_in_sizes=[1, 1, 3, 3],
-          strides=[1, 1],
-          dilations=[2, 1],
-          padding="VALID")
+    self._VerifyDilatedConvValues(
+        tensor_in_sizes=[0, 2, 3, 3],
+        filter_in_sizes=[1, 1, 3, 3],
+        strides=[1, 1],
+        dilations=[2, 1],
+        padding="VALID")
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Filter(self):
@@ -406,13 +399,12 @@ class Conv2DTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2FilterDilation(self):
-    if test.is_gpu_available(cuda_only=True):
-      self._VerifyDilatedConvValues(
-          tensor_in_sizes=[1, 2, 3, 3],
-          filter_in_sizes=[2, 2, 3, 3],
-          strides=[1, 1],
-          dilations=[1, 2],
-          padding="VALID")
+    self._VerifyDilatedConvValues(
+        tensor_in_sizes=[1, 2, 3, 3],
+        filter_in_sizes=[2, 2, 3, 3],
+        strides=[1, 1],
+        dilations=[1, 2],
+        padding="VALID")
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D1x2Filter(self):
@@ -430,13 +422,12 @@ class Conv2DTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D1x2FilterDilation(self):
-    if test.is_gpu_available(cuda_only=True):
-      self._VerifyDilatedConvValues(
-          tensor_in_sizes=[1, 2, 3, 3],
-          filter_in_sizes=[1, 2, 3, 3],
-          strides=[1, 1],
-          dilations=[2, 1],
-          padding="VALID")
+    self._VerifyDilatedConvValues(
+        tensor_in_sizes=[1, 2, 3, 3],
+        filter_in_sizes=[1, 2, 3, 3],
+        strides=[1, 1],
+        dilations=[2, 1],
+        padding="VALID")
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2FilterStride2(self):
@@ -512,13 +503,12 @@ class Conv2DTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2DKernelSizeMatchesInputSizeDilation(self):
-    if test.is_gpu_available(cuda_only=True):
-      self._VerifyDilatedConvValues(
-          tensor_in_sizes=[1, 3, 3, 1],
-          filter_in_sizes=[2, 2, 1, 2],
-          strides=[1, 1],
-          dilations=[2, 2],
-          padding="VALID")
+    self._VerifyDilatedConvValues(
+        tensor_in_sizes=[1, 3, 3, 1],
+        filter_in_sizes=[2, 2, 1, 2],
+        strides=[1, 1],
+        dilations=[2, 2],
+        padding="VALID")
 
   # TODO(yzhwang): this currently fails.
   # self._VerifyValues(tensor_in_sizes=[1, 8, 8, 1],
@@ -1538,21 +1528,6 @@ class Conv2DTest(test.TestCase):
             use_gpu=False)
         self.evaluate(conv)
 
-  def testCPUConv2DDilatedUnimplemented(self):
-    with self.test_session(use_gpu=False):
-      with self.assertRaisesRegexp(errors_impl.UnimplementedError,
-                                   "dilated rate of 1 for now"):
-        conv = self._SetupValuesForDevice(
-            tensor_in_sizes=[1, 4, 4, 1],
-            filter_in_sizes=[2, 2, 1, 1],
-            dilations=[2, 1],
-            strides=[1, 1],
-            padding="VALID",
-            data_format="NHWC",
-            dtype=dtypes.float32,
-            use_gpu=False)
-        self.evaluate(conv)
-
 
 class DepthwiseConv2DTest(test.TestCase):
 
@@ -1887,7 +1862,7 @@ def GetInceptionFwdTest(input_size, filter_size, stride, padding,
 def GetInceptionFwdDilatedConvTest(input_size, filter_size, stride, padding):
 
   def Test(self):
-    if test.is_gpu_available(cuda_only=True) and stride == 1:
+    if stride == 1:
       tf_logging.info("Testing InceptionFwd with dilations %s",
                       (input_size, filter_size, stride, padding))
       self._VerifyDilatedConvValues(