From 179795c0067f05abe54904797288efebf6958b35 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 5 Feb 2018 16:32:38 -0800
Subject: [PATCH] Support negative axis in concatenation

PiperOrigin-RevId: 184605786
---
 tensorflow/contrib/lite/kernels/concatenation.cc      |  6 ++++--
 tensorflow/contrib/lite/kernels/concatenation_test.cc | 18 +++++++++++++++++-
 tensorflow/contrib/lite/testing/generate_examples.py  |  4 +++-
 .../graph_transformations/propagate_fixed_sizes.cc    | 11 +++++++----
 4 files changed, 31 insertions(+), 8 deletions(-)

diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc
index 9e7a123..7ff9075 100644
--- a/tensorflow/contrib/lite/kernels/concatenation.cc
+++ b/tensorflow/contrib/lite/kernels/concatenation.cc
@@ -49,6 +49,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // dimensions except 'axis' must be equal.
   TfLiteTensor* t0 = &context->tensors[node->inputs->data[0]];
   TfLiteType input_type = t0->type;
+  if (axis < 0) axis += t0->dims->size;
   TF_LITE_ENSURE(context, axis >= 0);
   TF_LITE_ENSURE(context, axis < t0->dims->size);
 
@@ -131,8 +132,9 @@ template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params =
       reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
-
+  int axis = params->axis;
   TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+  if (axis < 0) axis += output->dims->size;
 
   // TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should
   // allocate and populate these during Prepare().
@@ -141,7 +143,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 
 #define TF_LITE_CONCATENATION(type, scalar)                                  \
   VectorOfInputs<scalar> all_inputs(*context, *node->inputs);               \
   type::Concatenation<FusedActivationFunctionType::kNone, scalar>(          \
-      RemapDim(NumDimensions(output), params->axis), all_inputs.data(),     \
+      RemapDim(NumDimensions(output), axis), all_inputs.data(),             \
       all_inputs.dims(), node->inputs->size, GetTensorData<scalar>(output), \
       GetTensorDims(output))
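The convention used by both kernel changes above: a negative axis counts back
from the last dimension, so for a rank-4 tensor, axis -1 resolves to 3. A
minimal standalone sketch of that rule (NormalizeAxis is a hypothetical
illustration helper, not part of the TFLite sources):

#include <cassert>
#include <cstdio>

// Hypothetical helper mirroring the `if (axis < 0) axis += rank` lines in
// the patch: a negative axis counts back from the last dimension.
int NormalizeAxis(int axis, int rank) {
  if (axis < 0) axis += rank;        // e.g. axis = -1, rank = 4 -> axis = 3
  assert(axis >= 0 && axis < rank);  // same bounds Prepare() enforces
  return axis;
}

int main() {
  std::printf("%d\n", NormalizeAxis(-1, 4));  // prints 3
  std::printf("%d\n", NormalizeAxis(-4, 4));  // prints 0
  std::printf("%d\n", NormalizeAxis(2, 4));   // prints 2
  return 0;
}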
diff --git a/tensorflow/contrib/lite/kernels/concatenation_test.cc b/tensorflow/contrib/lite/kernels/concatenation_test.cc
index 499856a..ba1ffc5 100644
--- a/tensorflow/contrib/lite/kernels/concatenation_test.cc
+++ b/tensorflow/contrib/lite/kernels/concatenation_test.cc
@@ -94,7 +94,7 @@ TEST(ConcatenationOpTest, TwoDimensionalOneInput) {
   EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
 }
 
-TEST(ConcatenationOpTest, TwoInputsTwoAxis) {
+TEST(ConcatenationOpTest, TwoInputsTwoAxesNegativeAxes) {
   // We will concatenate two tensors along different dimensions.
   auto tensor0 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
   auto tensor1 = {7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
@@ -107,6 +107,14 @@ TEST(ConcatenationOpTest, TwoInputsTwoAxis) {
   EXPECT_THAT(m0.GetOutput(),
               ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
 
+  ConcatenationOpModel m0_negative({TensorType_FLOAT32, {2, 3}}, /*axis=*/-2,
+                                   /*num_inputs=*/2);
+  m0_negative.SetInput(0, tensor0);
+  m0_negative.SetInput(1, tensor1);
+  m0_negative.Invoke();
+  EXPECT_THAT(m0_negative.GetOutput(),
+              ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
+
   ConcatenationOpModel m1({TensorType_FLOAT32, {2, 3}}, /*axis=*/1,
                           /*num_inputs=*/2);
   m1.SetInput(0, tensor0);
@@ -114,6 +122,14 @@
   m1.Invoke();
   EXPECT_THAT(m1.GetOutput(),
               ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
+
+  ConcatenationOpModel m1_negative({TensorType_FLOAT32, {2, 3}}, /*axis=*/-1,
+                                   /*num_inputs=*/2);
+  m1_negative.SetInput(0, tensor0);
+  m1_negative.SetInput(1, tensor1);
+  m1_negative.Invoke();
+  EXPECT_THAT(m1_negative.GetOutput(),
+              ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
 }
 
 TEST(ConcatenationOpTest, FourInputs) {
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index b2227a7..6264daa 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -1001,13 +1001,15 @@ def make_concatenation_tests(zip_path):
   test_parameters = [{
       "base_shape": [[1, 3, 4, 3], [3, 4]],
       "num_tensors": [1, 2, 3, 4, 5, 6],
-      "axis": [0, 1, 2, 3],
+      "axis": [0, 1, 2, 3, -3, -2, -1],
   }]
 
   def get_shape(parameters, delta):
     """Return a tweaked version of 'base_shape'."""
     axis = parameters["axis"]
     shape = parameters["base_shape"][:]
+    if axis < 0:
+      axis += len(shape)
     if axis < len(shape):
       shape[axis] += delta
     return shape
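A worked check of the expectations above: the inputs are both 2x3, so rank is
2 and the negative axes resolve to -2 + 2 = 0 and -1 + 2 = 1. Concatenating
along axis 0 stacks the two tensors into a 4x3 result, {1..6} followed by
{7..12}; concatenating along axis 1 widens each row into a 2x6 result,
interleaving to {1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}. Each *_negative
model must therefore produce exactly the same output as its non-negative
counterpart, which is what the added EXPECT_THAT lines assert.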
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 7f26884..fa7e70d 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -546,6 +546,9 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
   // Use 0 input as basis for output dimensions.
   const auto& first_input_array = model->GetArray(op->inputs[0]);
   output_array.copy_shape(first_input_array.shape());
+  // Negative axis means the count starts at the back of the dims().
+  int axis = op->axis;
+  if (axis < 0) axis += first_input_array.shape().dims().size();
   // Determine the concat size, and enforce that all inputs have
   // the same dimensions count.
   int concat_size = 0;
@@ -558,14 +561,14 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
     CHECK_EQ(input_array.shape().dimensions_count(),
              output_array.shape().dimensions_count());
     const std::vector<int>& input_dims = input_array.shape().dims();
-    CHECK_LT(op->axis, input_dims.size());
-    concat_size += input_dims[op->axis];
+    CHECK_LT(axis, input_dims.size());
+    concat_size += input_dims[axis];
   }
   // Write out the concat_size on the output array shape.
   auto& output_shape = *output_array.mutable_shape();
   auto& output_dims = *output_shape.mutable_dims();
-  CHECK_LT(op->axis, output_shape.dimensions_count());
-  output_dims[op->axis] = concat_size;
+  CHECK_LT(axis, output_shape.dimensions_count());
+  output_dims[axis] = concat_size;
 }
 
 void ProcessRangeOperator(Model* model, RangeOperator* op) {
-- 
2.7.4
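The propagate_fixed_sizes change affects only output-shape inference: every
dimension is copied from the first input, then the size along the normalized
concatenation axis is replaced by the sum of the inputs' sizes on that axis.
A standalone sketch of that rule, assuming a hypothetical ConcatOutputShape
helper (not a toco function):

#include <cassert>
#include <vector>

// Illustration of the shape rule in ProcessConcatenationOperator:
// out[d] == in[d] for d != axis, and out[axis] == sum of in_i[axis].
std::vector<int> ConcatOutputShape(const std::vector<std::vector<int>>& inputs,
                                   int axis) {
  std::vector<int> out = inputs[0];
  const int rank = static_cast<int>(out.size());
  if (axis < 0) axis += rank;  // same negative-axis handling as the patch
  assert(axis >= 0 && axis < rank);
  int concat_size = 0;
  for (const auto& in : inputs) {
    assert(static_cast<int>(in.size()) == rank);  // mirrors the CHECK_EQ above
    concat_size += in[axis];
  }
  out[axis] = concat_size;
  return out;
}

int main() {
  // Two {2, 3} inputs along axis -1 yield {2, 6}, matching m1_negative.
  std::vector<int> out = ConcatOutputShape({{2, 3}, {2, 3}}, -1);
  assert((out == std::vector<int>{2, 6}));
  return 0;
}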