return;
}
-    Tensor* output = nullptr;
-    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+    TensorShape collapsed_input_shape;
+    TensorShape collapsed_output_shape;
+    Tensor collapsed_paddings;
+    if (fixed_dims > 1 &&
+        CollapseAdjacentNonPaddedDimensions(
+            in0.shape(), in1, output_shape, &collapsed_input_shape,
+            &collapsed_paddings, &collapsed_output_shape)) {
+      Tensor collapsed_input;
+      CHECK(collapsed_input.CopyFrom(in0, collapsed_input_shape));
+      Tensor collapsed_output;
+      OP_REQUIRES_OK(context, context->allocate_temp(collapsed_input.dtype(),
+                                                     collapsed_output_shape,
+                                                     &collapsed_output));
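+      // Bind the collapsed paddings through a const reference so that
+      // matrix<Tpadding>() selects the const overload and yields a
+      // ConstMatrix view.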
+      const Tensor& collapsed_paddings_ref = collapsed_paddings;
+      typename TTypes<Tpadding>::ConstMatrix collapsed_paddings_matrix =
+          collapsed_paddings_ref.matrix<Tpadding>();
+      OperateWithVariableRank(context, collapsed_input_shape.dims(),
+                              collapsed_input, collapsed_paddings_matrix,
+                              pad_value, &collapsed_output);
+
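+      // CopyFrom shares the underlying buffer, so this is a metadata-only
+      // reshape of the collapsed result back to the original output shape.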
+      Tensor output;
+      CHECK(output.CopyFrom(collapsed_output, output_shape));
+      context->set_output(0, output);
+    } else {
+      Tensor* output = nullptr;
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(0, output_shape, &output));
+      OperateWithVariableRank(context, fixed_dims, in0, paddings, pad_value,
+                              output);
+    }
+  }
+
+ private:
+  // Collapses adjacent dimensions that are not padded to one dimension for
+  // speed. Returns true if any two dimensions are collapsed. For example,
+  //
+  //   Pad(input_shape=[8, 28, 28, 3],
+  //       paddings=[[0, 0], [0, 0], [0, 0], [0, 1]])
+  //   is equivalent to
+  //   Pad(input_shape=[6272, 3],
+  //       paddings=[[0, 0], [0, 1]])
+  //
+  // input_shape: the original input shape.
+  // paddings_as_tensor: the original paddings.
+  // output_shape: the original output shape.
+  // collapsed_input_shape: the input shape after collapsing.
+  // collapsed_paddings_as_tensor: the paddings after collapsing.
+  // collapsed_output_shape: the output shape after collapsing.
+  static bool CollapseAdjacentNonPaddedDimensions(
+      const TensorShape& input_shape, const Tensor& paddings_as_tensor,
+      const TensorShape& output_shape, TensorShape* collapsed_input_shape,
+      Tensor* collapsed_paddings_as_tensor,
+      TensorShape* collapsed_output_shape) {
+    bool collapsed = false;
+    typename TTypes<Tpadding>::ConstMatrix paddings =
+        paddings_as_tensor.matrix<Tpadding>();
+    std::vector<std::pair<Tpadding, Tpadding>> collapsed_paddings;
+    int i = 0;
+    while (i < paddings.dimension(0)) {
+      if (paddings(i, 0) != 0 || paddings(i, 1) != 0) {
+        // If padded, copy the original dimension over.
+        collapsed_input_shape->InsertDim(collapsed_input_shape->dims(),
+                                         input_shape.dim_size(i));
+        collapsed_output_shape->InsertDim(collapsed_output_shape->dims(),
+                                          output_shape.dim_size(i));
+        collapsed_paddings.push_back({paddings(i, 0), paddings(i, 1)});
+        ++i;
+      } else {
+        // If not padded, find the next dimension that is padded and collapse
+        // all dimensions in between to one dimension.
+        int64 collapsed_input_dim_size = input_shape.dim_size(i);
+        int64 collapsed_output_dim_size = output_shape.dim_size(i);
+        ++i;
+        while (i < paddings.dimension(0) && paddings(i, 0) == 0 &&
+               paddings(i, 1) == 0) {
+          collapsed = true;
+          collapsed_input_dim_size *= input_shape.dim_size(i);
+          collapsed_output_dim_size *= output_shape.dim_size(i);
+          ++i;
+        }
+        collapsed_input_shape->InsertDim(collapsed_input_shape->dims(),
+                                         collapsed_input_dim_size);
+        collapsed_output_shape->InsertDim(collapsed_output_shape->dims(),
+                                          collapsed_output_dim_size);
+        collapsed_paddings.push_back({0, 0});
+      }
+    }
+
+    // Copy collapsed_paddings to collapsed_paddings_as_tensor.
+    *collapsed_paddings_as_tensor =
+        Tensor(paddings_as_tensor.dtype(),
+               TensorShape({static_cast<int64>(collapsed_paddings.size()), 2}));
+    auto collapsed_paddings_as_matrix =
+        collapsed_paddings_as_tensor->matrix<Tpadding>();
+    for (size_t i = 0; i < collapsed_paddings.size(); ++i) {
+      collapsed_paddings_as_matrix(i, 0) = collapsed_paddings[i].first;
+      collapsed_paddings_as_matrix(i, 1) = collapsed_paddings[i].second;
+    }
+    return collapsed;
+  }
+
+  void OperateWithVariableRank(OpKernelContext* context, int fixed_dims,
+                               const Tensor& input,
+                               typename TTypes<Tpadding>::ConstMatrix paddings,
+                               T pad_value, Tensor* output) {
// Invoke the dims-specific implementation.
switch (fixed_dims) {
case 0:
-        Operate<0>(context, in0.tensor<T, 0>(), paddings, pad_value, output);
+        Operate<0>(context, input.tensor<T, 0>(), paddings, pad_value, output);
break;
case 1:
// TODO(irving): Once Pad doesn't need a scalar special case,
// change flat to tensor. That is, once !allow_legacy_scalars().
-        Operate<1>(context, in0.flat<T>(), paddings, pad_value, output);
+        Operate<1>(context, input.flat<T>(), paddings, pad_value, output);
break;
case 2:
-        Operate<2>(context, in0.tensor<T, 2>(), paddings, pad_value, output);
+        Operate<2>(context, input.tensor<T, 2>(), paddings, pad_value, output);
break;
case 3:
-        Operate<3>(context, in0.tensor<T, 3>(), paddings, pad_value, output);
+        Operate<3>(context, input.tensor<T, 3>(), paddings, pad_value, output);
break;
case 4:
-        Operate<4>(context, in0.tensor<T, 4>(), paddings, pad_value, output);
+        Operate<4>(context, input.tensor<T, 4>(), paddings, pad_value, output);
break;
case 5:
-        Operate<5>(context, in0.tensor<T, 5>(), paddings, pad_value, output);
+        Operate<5>(context, input.tensor<T, 5>(), paddings, pad_value, output);
break;
case 6:
-        Operate<6>(context, in0.tensor<T, 6>(), paddings, pad_value, output);
+        Operate<6>(context, input.tensor<T, 6>(), paddings, pad_value, output);
break;
default:
OP_REQUIRES(context, false,
errors::InvalidArgument("Only ranks up to 6 supported: ",
-                                            in0.shape().DebugString()));
+                                            input.shape().DebugString()));
}
}
- private:
template <int Dims>
void Operate(OpKernelContext* context,
typename TTypes<T, Dims>::ConstTensor input,
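The collapsing transformation implemented above can be modeled outside the kernel. Below is a minimal Python sketch of the same loop, useful for reasoning about which dimensions merge; collapse_adjacent_non_padded_dims is a hypothetical standalone helper written for this note, not part of the patch:

def collapse_adjacent_non_padded_dims(input_shape, paddings):
  """Mirrors CollapseAdjacentNonPaddedDimensions: merges each run of
  dimensions whose paddings are [0, 0] into a single dimension."""
  collapsed_shape, collapsed_paddings = [], []
  i = 0
  while i < len(paddings):
    if paddings[i] != [0, 0]:
      # Padded dimension: carried over unchanged.
      collapsed_shape.append(input_shape[i])
      collapsed_paddings.append(paddings[i])
      i += 1
    else:
      # Run of non-padded dimensions: collapse into one.
      size = input_shape[i]
      i += 1
      while i < len(paddings) and paddings[i] == [0, 0]:
        size *= input_shape[i]
        i += 1
      collapsed_shape.append(size)
      collapsed_paddings.append([0, 0])
  return collapsed_shape, collapsed_paddings

# The example from the doc comment: only the last dimension is padded,
# so [8, 28, 28, 3] collapses to [6272, 3].
assert collapse_adjacent_non_padded_dims(
    [8, 28, 28, 3], [[0, 0], [0, 0], [0, 0], [0, 1]]) == (
        [6272, 3], [[0, 0], [0, 1]])

The C++ version additionally emits the collapsed output shape; the sketch omits it, since it follows directly from the collapsed input shape plus the collapsed paddings.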
self.assertAllEqual(inp, out)
self.assertShapeEqual(inp, tf_val)

+  def testCollapseAdjacentNonPaddedDimensions(self):
+    # pyformat: disable
+    for paddings_value in [[[0, 0], [0, 0], [0, 0], [0, 1]],
+                           [[0, 0], [2, 3], [0, 0], [0, 0]],
+                           [[0, 0], [0, 0], [0, 0], [0, 0]]]:
+      # pyformat: enable
+      inp = constant_op.constant(1.0, shape=[8, 28, 28, 3])
+      paddings = constant_op.constant(paddings_value, dtype=dtypes.int32)
+      padded = array_ops.pad(inp, paddings)
+      middle = array_ops.slice(padded, [row[0] for row in paddings_value],
+                               [dim.value for dim in inp.shape.dims])
+      left = array_ops.slice(padded, [0, 0, 0, 0],
+                             [row[0] for row in paddings_value])
+      right = array_ops.slice(
+          padded,
+          [paddings_value[i][0] + inp.shape.dims[i].value for i in range(4)],
+          [-1, -1, -1, -1])
+      with self.test_session(use_gpu=True):
+        self.assertAllEqual(inp.eval(), middle.eval())
+        self.assertAllEqual(
+            np.zeros([row[0] for row in paddings_value]), left.eval())
+        self.assertAllEqual(
+            np.zeros([row[1] for row in paddings_value]), right.eval())
+
+
if __name__ == "__main__":
test.main()
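As a sanity check of the equivalence stated in the kernel's doc comment, here is a NumPy-only sketch; it mirrors the fast path's reshape-pad-reshape reasoning under the doc comment's example shapes, and does not exercise the kernel itself:

import numpy as np

x = np.arange(8 * 28 * 28 * 3, dtype=np.float32).reshape([8, 28, 28, 3])
# Pad the rank-4 tensor directly.
direct = np.pad(x, [[0, 0], [0, 0], [0, 0], [0, 1]], mode="constant")
# Collapse the three leading non-padded dimensions, pad at rank 2, then
# restore the original output shape.
fast = np.pad(x.reshape([6272, 3]), [[0, 0], [0, 1]], mode="constant")
assert np.array_equal(direct, fast.reshape(direct.shape))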