Collapse adjacent dimensions that have no paddings.
authorJingyue Wu <jingyue@google.com>
Fri, 9 Mar 2018 06:05:27 +0000 (22:05 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 9 Mar 2018 06:09:13 +0000 (22:09 -0800)
For example,

tf.pad(<4D tensor>, [[0, 0], [0, 0], [0, 0], [0, 1]])

is equivalent to a 2D pad, which is faster.

PiperOrigin-RevId: 188440916

tensorflow/core/kernels/pad_op.cc
tensorflow/python/kernel_tests/pad_op_test.py

index 77c1808..ce79541 100644 (file)
@@ -104,42 +104,144 @@ class PadOp : public OpKernel {
       return;
     }
 
-    Tensor* output = nullptr;
-    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+    TensorShape collapsed_input_shape;
+    TensorShape collapsed_output_shape;
+    Tensor collapsed_paddings;
+    if (fixed_dims > 1 &&
+        CollapseAdjacentNonPaddedDimensions(
+            in0.shape(), in1, output_shape, &collapsed_input_shape,
+            &collapsed_paddings, &collapsed_output_shape)) {
+      Tensor collapsed_input;
+      CHECK(collapsed_input.CopyFrom(in0, collapsed_input_shape));
+      Tensor collapsed_output;
+      OP_REQUIRES_OK(context, context->allocate_temp(collapsed_input.dtype(),
+                                                     collapsed_output_shape,
+                                                     &collapsed_output));
+      const Tensor& collapsed_paddings_ref = collapsed_paddings;
+      typename TTypes<Tpadding>::ConstMatrix collapsed_paddings_matrix =
+          collapsed_paddings_ref.matrix<Tpadding>();
 
+      OperateWithVariableRank(context, collapsed_input_shape.dims(),
+                              collapsed_input, collapsed_paddings_matrix,
+                              pad_value, &collapsed_output);
+
+      Tensor output;
+      CHECK(output.CopyFrom(collapsed_output, output_shape));
+      context->set_output(0, output);
+    } else {
+      Tensor* output = nullptr;
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(0, output_shape, &output));
+      OperateWithVariableRank(context, fixed_dims, in0, paddings, pad_value,
+                              output);
+    }
+  }
+
+ private:
+  // Collapses adjacent dimensions that are not padded to one dimension for
+  // speed. Returns true if any two dimensions are collapsed. For example,
+  //
+  //   Pad(input_shape=[8, 28, 28, 3],
+  //       paddings=[[0, 0], [0, 0], [0, 0], [0, 1]]
+  // is equivalent to
+  //   Pad(input_shape=[6272, 3],
+  //       paddings=[[0, 0], [0, 1]])
+  //
+  // input_shape: the original input shape.
+  // paddings_as_tensor: the original paddings.
+  // output_shape: the original output shape.
+  // collapsed_input_shape: the input shape after collapsing.
+  // collapsed_paddings_as_tensor: the paddings after collapsing.
+  // collapsed_output_shape: the output shape after collapsing.
+  static bool CollapseAdjacentNonPaddedDimensions(
+      const TensorShape& input_shape, const Tensor& paddings_as_tensor,
+      const TensorShape& output_shape, TensorShape* collapsed_input_shape,
+      Tensor* collapsed_paddings_as_tensor,
+      TensorShape* collapsed_output_shape) {
+    bool collapsed = false;
+    typename TTypes<Tpadding>::ConstMatrix paddings =
+        paddings_as_tensor.matrix<Tpadding>();
+    std::vector<std::pair<int, int>> collapsed_paddings;
+    int i = 0;
+    while (i < paddings.dimension(0)) {
+      if (paddings(i, 0) != 0 || paddings(i, 1) != 0) {
+        // If padded, copy the original dimension over.
+        collapsed_input_shape->InsertDim(collapsed_input_shape->dims(),
+                                         input_shape.dim_size(i));
+        collapsed_output_shape->InsertDim(collapsed_output_shape->dims(),
+                                          output_shape.dim_size(i));
+        collapsed_paddings.push_back({paddings(i, 0), paddings(i, 1)});
+        ++i;
+      } else {
+        // If not padded, find the next dimension that is padded and collapse
+        // all dimensions in between to one dimension.
+        int64 collapsed_input_dim_size = input_shape.dim_size(i);
+        int64 collapsed_output_dim_size = output_shape.dim_size(i);
+        ++i;
+        while (i < paddings.dimension(0) && paddings(i, 0) == 0 &&
+               paddings(i, 1) == 0) {
+          collapsed = true;
+          collapsed_input_dim_size *= input_shape.dim_size(i);
+          collapsed_output_dim_size *= output_shape.dim_size(i);
+          ++i;
+        }
+        collapsed_input_shape->InsertDim(collapsed_input_shape->dims(),
+                                         collapsed_input_dim_size);
+        collapsed_output_shape->InsertDim(collapsed_output_shape->dims(),
+                                          collapsed_output_dim_size);
+        collapsed_paddings.push_back({0, 0});
+      }
+    }
+
+    // Copy collapsed_paddings to collapsed_paddings_as_tensor.
+    *collapsed_paddings_as_tensor =
+        Tensor(paddings_as_tensor.dtype(),
+               TensorShape({static_cast<int64>(collapsed_paddings.size()), 2}));
+    auto collapsed_paddings_as_matrix =
+        collapsed_paddings_as_tensor->matrix<Tpadding>();
+    for (size_t i = 0; i < collapsed_paddings.size(); ++i) {
+      collapsed_paddings_as_matrix(i, 0) = collapsed_paddings[i].first;
+      collapsed_paddings_as_matrix(i, 1) = collapsed_paddings[i].second;
+    }
+    return collapsed;
+  }
+
+  void OperateWithVariableRank(OpKernelContext* context, int fixed_dims,
+                               const Tensor& input,
+                               typename TTypes<Tpadding>::ConstMatrix paddings,
+                               T pad_value, Tensor* output) {
     // Invoke the dims-specific implementation.
     switch (fixed_dims) {
       case 0:
-        Operate<0>(context, in0.tensor<T, 0>(), paddings, pad_value, output);
+        Operate<0>(context, input.tensor<T, 0>(), paddings, pad_value, output);
         break;
       case 1:
         // TODO(irving): Once Pad doesn't need a scalar special case,
         // change flat to tensor.  That is, once !allow_legacy_scalars().
-        Operate<1>(context, in0.flat<T>(), paddings, pad_value, output);
+        Operate<1>(context, input.flat<T>(), paddings, pad_value, output);
         break;
       case 2:
-        Operate<2>(context, in0.tensor<T, 2>(), paddings, pad_value, output);
+        Operate<2>(context, input.tensor<T, 2>(), paddings, pad_value, output);
         break;
       case 3:
-        Operate<3>(context, in0.tensor<T, 3>(), paddings, pad_value, output);
+        Operate<3>(context, input.tensor<T, 3>(), paddings, pad_value, output);
         break;
       case 4:
-        Operate<4>(context, in0.tensor<T, 4>(), paddings, pad_value, output);
+        Operate<4>(context, input.tensor<T, 4>(), paddings, pad_value, output);
         break;
       case 5:
-        Operate<5>(context, in0.tensor<T, 5>(), paddings, pad_value, output);
+        Operate<5>(context, input.tensor<T, 5>(), paddings, pad_value, output);
         break;
       case 6:
-        Operate<6>(context, in0.tensor<T, 6>(), paddings, pad_value, output);
+        Operate<6>(context, input.tensor<T, 6>(), paddings, pad_value, output);
         break;
       default:
         OP_REQUIRES(context, false,
                     errors::InvalidArgument("Only ranks up to 6 supported: ",
-                                            in0.shape().DebugString()));
+                                            input.shape().DebugString()));
     }
   }
 
- private:
   template <int Dims>
   void Operate(OpKernelContext* context,
                typename TTypes<T, Dims>::ConstTensor input,
index aaeb3b1..236aa4a 100644 (file)
@@ -336,5 +336,30 @@ class PadOpTest(test.TestCase):
       self.assertAllEqual(inp, out)
       self.assertShapeEqual(inp, tf_val)
 
+  def testCollapseAdjacentNonPaddedDimensions(self):
+    # pyformat: disable
+    for paddings_value in [[[0, 0], [0, 0], [0, 0], [0, 1]],
+                           [[0, 0], [2, 3], [0, 0], [0, 0]],
+                           [[0, 0], [0, 0], [0, 0], [0, 0]]]:
+      # pyformat: enable
+      inp = constant_op.constant(1.0, shape=[8, 28, 28, 3])
+      paddings = constant_op.constant(paddings_value, dtype=dtypes.int32)
+      padded = array_ops.pad(inp, paddings)
+      middle = array_ops.slice(padded, [row[0] for row in paddings_value],
+                               [dim.value for dim in inp.shape.dims])
+      left = array_ops.slice(padded, [0, 0, 0, 0],
+                             [row[0] for row in paddings_value])
+      right = array_ops.slice(
+          padded,
+          [paddings_value[i][0] + inp.shape.dims[i].value for i in range(4)],
+          [-1, -1, -1, -1])
+      with self.test_session(use_gpu=True):
+        self.assertAllEqual(inp.eval(), middle.eval())
+        self.assertAllEqual(
+            np.zeros([row[0] for row in paddings_value]), left.eval())
+        self.assertAllEqual(
+            np.zeros([row[1] for row in paddings_value]), right.eval())
+
+
 if __name__ == "__main__":
   test.main()