Transposes are can be merged into reshapes when the ordering of non-one
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 12 Mar 2018 18:24:22 +0000 (11:24 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 12 Mar 2018 18:28:22 +0000 (11:28 -0700)
dimensions remains unchanged.

PiperOrigin-RevId: 188751074

tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc

index c2b1660..5a36a90 100644 (file)
@@ -21,6 +21,33 @@ limitations under the License.
 
 namespace toco {
 
+namespace {
+
+bool TransposeAffectsMemoryOrder(std::vector<int> perm,
+                                 std::vector<int> in_shape) {
+  CHECK_EQ(perm.size(), in_shape.size());
+  // See what the ordering of the non-unary columns are before and after
+  // transpose permutation. If the major indices stay in the same order (not
+  // just the shape) then the flat buffer representation shouldn't change.
+  std::vector<int> old_major_index_ordering;
+  std::vector<int> new_major_index_ordering;
+  for (int i = 0; i < in_shape.size(); i++) {
+    if (in_shape[i] != 1) {
+      old_major_index_ordering.push_back(i);
+    }
+
+    if (in_shape[perm[i]] != 1) {
+      new_major_index_ordering.push_back(perm[i]);
+    }
+  }
+
+  CHECK_EQ(new_major_index_ordering.size(), old_major_index_ordering.size());
+
+  return old_major_index_ordering != new_major_index_ordering;
+}
+
+}  // namespace
+
 bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
   auto transpose_it = model->operators.begin() + op_index;
   if (transpose_it->get()->type != OperatorType::kTranspose) {
@@ -29,23 +56,26 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
   TransposeOperator* transpose_op =
       static_cast<TransposeOperator*>(transpose_it->get());
 
+  const auto& input_array = model->GetArray(transpose_op->inputs[0]);
   const auto& output_array = model->GetArray(transpose_op->outputs[0]);
-  if (!output_array.has_shape()) {
+  if (!input_array.has_shape() || !output_array.has_shape()) {
     // Yield until PropagateFixedSizes has been run on this op.
     return false;
   }
   // Note: We can assume we have error checked inputs in PropagateFixedSizes.
 
-  // This transpose is trivial if we only have one non-unitary dimension.
-  std::vector<int> const& dims = output_array.shape().dims();
-  unsigned non_unitary_axis_count = 0;
-  for (int i = 0; i < dims.size(); i++) {
-    if (dims[i] != 1) {
-      non_unitary_axis_count++;
-    }
+  // Check that the permutation has propogated.
+  std::vector<int> const& perm = transpose_op->perm;
+  if (perm.empty()) {
+    return false;
   }
-  if (non_unitary_axis_count > 1) {
-    // Transpose is not trivial
+
+  // This transpose is trivial if non-unitary dimensions remain in the same
+  // order.
+  std::vector<int> const& input_dims = input_array.shape().dims();
+  std::vector<int> const& output_dims = output_array.shape().dims();
+
+  if (TransposeAffectsMemoryOrder(perm, input_dims)) {
     return false;
   }
 
@@ -61,11 +91,11 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
   string shape_array_name = toco::AvailableArrayName(*model, perm_array_name);
   Array& shape_array = model->GetOrCreateArray(shape_array_name);
   *(shape_array.mutable_shape()->mutable_dims()) = {
-      1, static_cast<int>(dims.size())};
+      1, static_cast<int>(output_dims.size())};
   reshape_op->inputs.push_back(shape_array_name);
   shape_array.data_type = ArrayDataType::kInt32;
   auto& shape_buffer = shape_array.GetMutableBuffer<ArrayDataType::kInt32>();
-  shape_buffer.data = dims;
+  shape_buffer.data = output_dims;
 
   // Delete perm array if unused
   if (IsDiscardableArray(*model, perm_array_name) &&