Three operators are added that can be used to decrease the number of
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 3 Apr 2018 17:35:00 +0000 (10:35 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 3 Apr 2018 17:37:47 +0000 (10:37 -0700)
transpose/reshape occurences. These operators are:
- Swap elementwise
- Swap reshape-transpose
- Fuse transpose-reshape

Swap elementwise groups operations that operate on value, and move values. This
allows all movement operations (reshape, transpose, etc) to group at one
portion of the graph while value manipulations group on the other end.

Swap reshape/transpose finds cases where the reshape-transpose can be reorderd.
This allows the reshape-transpose-reshape case to be transformed to
transpose-reshape-reshape, which can then merge the reshape-reshape.

Finally the transpose-reshape merge allows cases where, when reshape maintains
the same dimensions and does not affect memory ordering, the transpose and
reshape can be merged into a single transpose operator.

PiperOrigin-RevId: 191462038

tensorflow/contrib/lite/toco/BUILD
tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc [new file with mode: 0644]
tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc [deleted file]
tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc [new file with mode: 0644]
tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc [new file with mode: 0644]
tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
tensorflow/contrib/lite/toco/toco_tooling.cc
tensorflow/contrib/lite/toco/tooling_util.cc
tensorflow/contrib/lite/toco/tooling_util.h

index 2dd689a..8a35fb9 100644 (file)
@@ -234,6 +234,7 @@ cc_library(
         "graph_transformations/identify_relu1.cc",
         "graph_transformations/lstm_utils.cc",
         "graph_transformations/make_initial_dequantize_operator.cc",
+        "graph_transformations/merge_reshape_into_preceding_transpose.cc",
         "graph_transformations/propagate_activation_function_into_constants.cc",
         "graph_transformations/propagate_array_data_types.cc",
         "graph_transformations/propagate_fixed_sizes.cc",
@@ -251,7 +252,8 @@ cc_library(
         "graph_transformations/remove_trivial_reshape.cc",
         "graph_transformations/remove_trivial_slice.cc",
         "graph_transformations/remove_unused_op.cc",
-        "graph_transformations/reorder_activation_functions.cc",
+        "graph_transformations/reorder_elementwise_unary.cc",
+        "graph_transformations/reorder_reshape_transpose.cc",
         "graph_transformations/resolve_batch_normalization.cc",
         "graph_transformations/resolve_batch_to_space_nd_attributes.cc",
         "graph_transformations/resolve_constant_binary.cc",
index 76ec02a..27c5044 100644 (file)
@@ -128,6 +128,7 @@ DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Pool)
 DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell)
 DECLARE_GRAPH_TRANSFORMATION(SplitLstmCellInputs)
 DECLARE_GRAPH_TRANSFORMATION(MergeLstmCellInputs)
+DECLARE_GRAPH_TRANSFORMATION(MergeReshapeIntoPrecedingTranspose)
 DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1)
 DECLARE_GRAPH_TRANSFORMATION(IdentifyPRelu)
 DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv)
@@ -152,7 +153,8 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantUnaryOperator)
 DECLARE_GRAPH_TRANSFORMATION(CreateIm2colArrays)
 DECLARE_GRAPH_TRANSFORMATION(DropIm2colArrays)
 DECLARE_GRAPH_TRANSFORMATION(ReadFakeQuantMinMax)
-DECLARE_GRAPH_TRANSFORMATION(ReorderActivationFunctions)
+DECLARE_GRAPH_TRANSFORMATION(ReorderElementwiseUnary)
+DECLARE_GRAPH_TRANSFORMATION(ReorderReshapeTranspose)
 DECLARE_GRAPH_TRANSFORMATION(ResolveReorderAxes)
 DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowConcat)
 DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
new file mode 100644 (file)
index 0000000..5065004
--- /dev/null
@@ -0,0 +1,190 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool OperatorReady(const Model& model, const Operator* op) {
+  if (!model.HasArray(op->inputs[0]) || !model.HasArray(op->inputs[1]) ||
+      !model.HasArray(op->outputs[0])) {
+    // Arrays are missing.
+    return false;
+  }
+
+  if (!model.GetArray(op->inputs[0]).has_shape() ||
+      !model.GetArray(op->outputs[0]).has_shape()) {
+    // Input and output needs the shape.
+    return false;
+  }
+
+  if (!model.GetArray(op->inputs[1]).buffer) {
+    // Buffer needs to be a constant.
+    return false;
+  }
+
+  return true;
+}
+
+// Returns whether the reshape could be a transpose.
+std::vector<int32> ReshapeToTranspose(const Model& model,
+                                      const TensorFlowReshapeOperator* op) {
+  CHECK(!op->shape.empty());
+  CHECK(model.HasArray(op->inputs[0]));
+  CHECK(model.HasArray(op->outputs[0]));
+
+  const auto& input_array = model.GetArray(op->inputs[0]);
+  const auto& output_array = model.GetArray(op->outputs[0]);
+
+  CHECK(input_array.has_shape());
+  CHECK(output_array.has_shape());
+
+  std::vector<int> in_shape = input_array.shape().dims();
+  std::vector<int> out_shape = output_array.shape().dims();
+
+  std::vector<int> one_indices;
+  std::vector<int> not_one_indices;
+
+  // Separate into one indices and not one indices.
+  for (int i = 0; i < in_shape.size(); i++) {
+    if (in_shape[i] == 1) {
+      one_indices.push_back(i);
+    } else {
+      not_one_indices.push_back(i);
+    }
+  }
+
+  // Reorder the vertices.
+  std::vector<int> perm;
+  perm.reserve(in_shape.size());
+  int one_index = 0;
+  int not_one_index = 0;
+  for (const auto val : out_shape) {
+    if (val == 1) {
+      perm.push_back(one_indices[one_index]);
+      one_index++;
+    } else {
+      perm.push_back(not_one_indices[not_one_index]);
+      not_one_index++;
+    }
+  }
+
+  return perm;
+}
+
+}  // namespace
+
+// When a transpose is fed into a reshape, it is possible for the two operators
+// to be merged if the reshape does not affect memory ordering and does not
+// affects the number of dimensions. This only occurs when only unary dimensions
+// are shifting position.
+bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
+                                             std::size_t op_index) {
+  auto it = model->operators.begin() + op_index;
+  auto* reshape_op = ConvertOperator<TensorFlowReshapeOperator*>(
+      it->get(), OperatorType::kTensorFlowReshape);
+
+  if (reshape_op == nullptr) {
+    return false;
+  }
+
+  if (!OperatorReady(*model, reshape_op) || reshape_op->shape.empty()) {
+    return false;
+  }
+
+  const string intermediate_name = reshape_op->inputs[0];
+  const string output_name = reshape_op->outputs[0];
+
+  // Guarantee the input is only consume by the reshape.
+  if (CountOpsWithInput(*model, intermediate_name) != 1) {
+    return false;
+  }
+
+  // Check for the parent operator.
+  const auto& transpose_it = FindOpWithOutput(*model, intermediate_name);
+  if (transpose_it == model->operators.end()) {
+    return false;
+  }
+
+  // Find the parent operator and guarantee it is a transpose.
+  TransposeOperator* transpose_op = ConvertOperator<TransposeOperator*>(
+      transpose_it->get(), OperatorType::kTranspose);
+
+  if (transpose_op == nullptr) {
+    return false;
+  }
+
+  if (!OperatorReady(*model, transpose_op) || transpose_op->perm.empty()) {
+    return false;
+  }
+
+  if (!ReshapeIsEquivalentToTranspose(*model, reshape_op,
+                                      false /*allow_extra_unary_dimensions*/)) {
+    return false;
+  }
+
+  // Check that the intermediate is not an output array.
+  if (!IsDiscardableArray(*model, intermediate_name)) {
+    AddMessageF(
+        "Cannot fuse %s and %s as it would invalidate the transpose "
+        "output array.",
+        LogName(*transpose_op), LogName(*reshape_op));
+    return false;
+  }
+
+  AddMessageF("Merging operations %s and %s", LogName(*transpose_op),
+              LogName(*reshape_op));
+
+  // const auto& intermediate_array = model->GetArray(intermediate_name);
+  // const auto& output_array = model->GetArray(output_name);
+
+  auto merged_perm = ReshapeToTranspose(*model, reshape_op);
+
+  // Combine the permutations.
+  const auto& transpose_perm = transpose_op->perm;
+  for (int i = 0; i < merged_perm.size(); i++) {
+    merged_perm[i] = transpose_perm[merged_perm[i]];
+  }
+
+  // Remove the reshape as passthrough operation.
+  if (!RemoveTrivialPassthroughOp(this, model, op_index)) {
+    return false;
+  }
+
+  // Update transpose_op's constant buffer to contain the new permutation.
+  model->GetArray(transpose_op->inputs[1])
+      .GetMutableBuffer<ArrayDataType::kInt32>()
+      .data = merged_perm;
+  transpose_op->perm = merged_perm;
+
+  // transpose_ops's shape will likely has changed.
+  model->GetArray(transpose_op->outputs[0]).clear_shape();
+
+  return true;
+}
+
+}  // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc
deleted file mode 100644 (file)
index 9852c86..0000000
+++ /dev/null
@@ -1,137 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
-#include "tensorflow/contrib/lite/toco/model.h"
-#include "tensorflow/contrib/lite/toco/runtime/types.h"
-#include "tensorflow/contrib/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
-
-namespace toco {
-
-bool ReorderActivationFunctions::Run(Model* model, std::size_t op_index) {
-  const auto ac_it = model->operators.begin() + op_index;
-  std::unique_ptr<Operator>& ac_op = *ac_it;
-  DCHECK(ac_op);
-
-  if (ac_op->type != OperatorType::kRelu6 &&
-      ac_op->type != OperatorType::kRelu1 &&
-      ac_op->type != OperatorType::kRelu) {
-    return false;
-  }
-
-  auto exchange_it = FindOpWithOutput(*model, ac_op->inputs[0]);
-  if (exchange_it == model->operators.end()) return false;
-  // Find the op producing the array passed to this activation function
-  std::unique_ptr<Operator>& exchange_op = *exchange_it;
-  DCHECK(exchange_op);
-
-  // Allow activation functions to move up over any operator that does not
-  // change the values.
-  switch (exchange_op->type) {
-    case OperatorType::kExpandDims:
-    case OperatorType::kSqueeze:
-    case OperatorType::kTensorFlowReshape:
-    case OperatorType::kTranspose:
-      break;
-    default:
-      return false;
-  }
-
-  DCHECK_EQ(exchange_op->outputs[0], ac_op->inputs[0]);
-  const auto exchange_op_input = exchange_op->inputs[0];
-  const auto intermediate_array = exchange_op->outputs[0];
-  const auto ac_op_output = ac_op->outputs[0];
-
-  int count_ops_consuming_output =
-      CountOpsWithInput(*model, intermediate_array);
-  DCHECK_GE(count_ops_consuming_output, 1);
-  if (count_ops_consuming_output > 1) {
-    AddMessageF(
-        "Not exchanging activation function with %s because it is consumed by "
-        "more than 1 other operator",
-        LogName(*exchange_op));
-    return false;
-  }
-
-  // If the ac_op was originally producing an output_array we can't trivially
-  // reorder as otherwise the output array name would change and break
-  // downstream assumptions. To work around that we perform some renaming below
-  // in that case at the cost of a bit more confusing array names in this rare
-  // case.
-  bool is_ac_op_output =
-      std::find(model->flags.output_arrays().begin(),
-                model->flags.output_arrays().end(),
-                ac_op_output) != model->flags.output_arrays().end();
-  if (is_ac_op_output) {
-    // To preserve the output array name of the activation function we need to
-    // create a temporary to use to pass between ac->ex.
-    //
-    // Original:
-    //  (a) -> EX -> (b) -> AC -> (c)
-    // Now:
-    //  (a) -> AC -> (c') -> EX -> (c)
-    AddMessageF(
-        "Exchanging activation function %s with %s but renaming to preserve "
-        "output array %s",
-        LogName(*ac_op), LogName(*exchange_op), ac_op->outputs[0]);
-
-    auto renamed_ac_op_output =
-        AvailableArrayName(*model, ac_op_output + "_exchange");
-    ac_op->inputs[0] = exchange_op_input;
-    ac_op->outputs[0] = renamed_ac_op_output;
-    model->EraseArray(exchange_op->outputs[0]);
-    exchange_op->inputs[0] = renamed_ac_op_output;
-    exchange_op->outputs[0] = ac_op_output;
-  } else {
-    // Simply swap the order and update consumers to use the exchange_op output
-    // array (b).
-    //
-    // Original:
-    //  (a) -> EX -> (b) -> AC -> (c)
-    // Now:
-    //  (a) -> AC -> (c) -> EX -> (b)
-    AddMessageF("Exchanging activation function %s with %s", LogName(*ac_op),
-                LogName(*exchange_op));
-
-    Operator* consumer = GetFirstOpWithInput(*model, ac_op_output);
-    while (consumer) {
-      for (int i = 0; i < consumer->inputs.size(); ++i) {
-        if (consumer->inputs[i] == ac_op_output) {
-          consumer->inputs[i] = intermediate_array;
-        }
-      }
-      consumer = GetFirstOpWithInput(*model, ac_op_output);
-    }
-    ac_op->inputs[0] = exchange_op_input;
-    exchange_op->inputs[0] = ac_op_output;
-  }
-
-  // Clear shapes; this will allow shape propagation to fix the sizes for us.
-  model->GetOrCreateArray(ac_op->outputs[0]).clear_shape();
-  model->GetOrCreateArray(exchange_op->outputs[0]).clear_shape();
-
-  // Finally, reorder operators.  Note that this only works when there are no
-  // other direct descendents of the exchange_op.
-  ac_op.swap(exchange_op);
-
-  return true;
-}
-
-}  // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc
new file mode 100644 (file)
index 0000000..9f5b792
--- /dev/null
@@ -0,0 +1,153 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool IsElementwiseOperator(OperatorType optype) {
+  switch (optype) {
+    case OperatorType::kCast:
+    case OperatorType::kExp:
+    case OperatorType::kFloor:
+    case OperatorType::kNeg:
+    case OperatorType::kRelu:
+    case OperatorType::kRelu1:
+    case OperatorType::kRelu6:
+    case OperatorType::kTanh:
+    case OperatorType::kTensorFlowSqrt:
+    case OperatorType::kTensorFlowSquare:
+      return true;
+    default:
+      return false;
+  }
+}
+
+bool IsMoveOperator(OperatorType optype) {
+  switch (optype) {
+    case OperatorType::kDepthToSpace:
+    case OperatorType::kExpandDims:
+    case OperatorType::kSpaceToDepth:
+    case OperatorType::kSqueeze:
+    case OperatorType::kTensorFlowReshape:
+    case OperatorType::kTranspose:
+      return true;
+    default:
+      return false;
+  }
+}
+
+}  // namespace
+
+// Swap elementwise operators such that all value operators occur before all
+// element move operators, e.g. negation then transpose.
+bool ReorderElementwiseUnary::Run(Model* model, std::size_t op_index) {
+  const auto element_op_it = model->operators.begin() + op_index;
+  std::unique_ptr<Operator>& element_op = *element_op_it;
+  if (!IsElementwiseOperator(element_op->type)) {
+    return false;
+  }
+
+  const string intermediate_name = element_op->inputs[0];
+  auto it = FindOpWithOutput(*model, intermediate_name);
+  if (it == model->operators.end()) {
+    AddMessageF("No preceding operator");
+    return false;
+  }
+
+  std::unique_ptr<Operator>& move_op = *it;
+  if (!IsMoveOperator(move_op->type)) {
+    AddMessageF("Preceding operator is not a move operator");
+    return false;
+  }
+
+  if (CountOpsWithInput(*model, intermediate_name) != 1) {
+    AddMessageF("Input %s used elsewhere", intermediate_name);
+    return false;
+  }
+
+  // Check that the intermediate is discardable.
+  if (!IsDiscardableArray(*model, intermediate_name)) {
+    AddMessageF(
+        "Cannot swap elementwise as it would invalidate %s which is "
+        "an output array.",
+        intermediate_name);
+    return false;
+  }
+
+  // op->inputs may change so we need to keep a value by copy.
+  const string input_name = move_op->inputs[0];
+  const string output_name = element_op->outputs[0];
+
+  AddMessageF("Swapping around operators with %s and %s", LogName(*element_op),
+              LogName(*move_op));
+
+  // If the output array is an exit node for the graph then we need to retain
+  // the name as an output node. This makes the naming scheme a little confusing
+  // but is required in this rare case.
+  if (!IsDiscardableArray(*model, output_name)) {
+    // The output name of the sequence needs to stay static, so create a new
+    // array new use for the intermediate.
+    const auto new_intermediate_name =
+        AvailableArrayName(*model, element_op->outputs[0] + "_reorder");
+    AddMessageF("Adding new array %s to preserve output array name %s",
+                new_intermediate_name, output_name);
+
+    element_op->inputs[0] = input_name;
+    element_op->outputs[0] = new_intermediate_name;
+    model->EraseArray(intermediate_name);
+    move_op->inputs[0] = new_intermediate_name;
+    move_op->outputs[0] = output_name;
+  } else {
+    // The intermediate array is now the output array.
+    for (int i = 0; i < model->operators.size(); i++) {
+      Operator* consumer = model->operators[i].get();
+      for (int j = 0; j < consumer->inputs.size(); j++) {
+        if (consumer->inputs[j] == output_name) {
+          consumer->inputs[j] = intermediate_name;
+        }
+      }
+    }
+
+    element_op->inputs[0] = input_name;
+    move_op->inputs[0] = output_name;
+  }
+
+  // Reset both arrays as shape, type, min/max, etc can all change because of
+  // the position swap.
+  model->EraseArray(element_op->outputs[0]);
+  model->EraseArray(move_op->outputs[0]);
+
+  // Reconstruct.
+  model->GetOrCreateArray(element_op->outputs[0]);
+  model->GetOrCreateArray(move_op->outputs[0]);
+
+  // Swap the order of the operators.
+  element_op.swap(move_op);
+
+  return true;
+}
+
+}  // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc
new file mode 100644 (file)
index 0000000..9e7fe1b
--- /dev/null
@@ -0,0 +1,248 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool OperatorReady(const Model& model, const Operator* op) {
+  if (!model.HasArray(op->inputs[0]) || !model.HasArray(op->inputs[1]) ||
+      !model.HasArray(op->outputs[0])) {
+    return false;
+  }
+
+  if (!model.GetArray(op->inputs[0]).has_shape() ||
+      !model.GetArray(op->outputs[0]).has_shape()) {
+    // Input and output needs the shape.
+    return false;
+  }
+
+  if (!model.GetArray(op->inputs[1]).buffer) {
+    // Buffer needs to be a constant.
+    return false;
+  }
+
+  return true;
+}
+
+// Utility function to filter out a value.
+void Filter(std::vector<int>* vec, int value) {
+  vec->erase(std::remove(vec->begin(), vec->end(), value), vec->end());
+}
+
+// Computes a new permutation used to swap a reshape-transpose to a
+// transpose-reshape. In this case the permutation operates on the intermediate
+// shape.
+std::vector<int> ComputeNewPerm(std::vector<int> input_dims,
+                                std::vector<int> intermediate_dims,
+                                std::vector<int> perm) {
+  // These are the major axis of the input.
+  std::vector<int> input_indices;
+  for (int i = 0; i < input_dims.size(); i++) {
+    if (input_dims[i] != 1) {
+      input_indices.push_back(i);
+    }
+  }
+
+  // This maps which indices of the input produced the intermediate indices for
+  // non-unary dimensions.
+  std::unordered_map<int, int> intermediate_to_input_indices_map;
+  for (int i = 0; i < intermediate_dims.size(); i++) {
+    if (intermediate_dims[i] != 1) {
+      intermediate_to_input_indices_map[i] =
+          input_indices[intermediate_to_input_indices_map.size()];
+    }
+  }
+
+  // Translate the transpose permutation to a new permutation starting with the
+  // major indices.
+  std::vector<int> new_perm;
+  new_perm.reserve(input_dims.size());
+  for (int i = 0; i < perm.size(); i++) {
+    if (intermediate_dims[perm[i]] == 1) continue;
+
+    new_perm.push_back(intermediate_to_input_indices_map[perm[i]]);
+  }
+
+  // Fill the rest of the transpose in with the ones.
+  for (int index = 0; index < input_dims.size(); index++) {
+    if (input_dims[index] == 1) {
+      new_perm.push_back(index);
+    }
+  }
+
+  CHECK_EQ(new_perm.size(), input_dims.size());
+  return new_perm;
+}
+
+}  // namespace
+
+// Swaps reshape-transpose to transpose-reshape whenever possible. This is
+// possible when the reshape does not affect memory ordering.
+bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) {
+  auto transpose_it = model->operators.begin() + op_index;
+
+  TransposeOperator* transpose_op = ConvertOperator<TransposeOperator*>(
+      transpose_it->get(), OperatorType::kTranspose);
+
+  if (transpose_op == nullptr) {
+    return false;
+  }
+
+  if (!OperatorReady(*model, transpose_op) || transpose_op->perm.empty()) {
+    // Wait for values to propagate.
+    return false;
+  }
+
+  // Find the operator that produces the transpose op.
+  auto reshape_it = FindOpWithOutput(*model, transpose_op->inputs[0]);
+  if (reshape_it == model->operators.end()) {
+    return false;
+  }
+
+  TensorFlowReshapeOperator* reshape_op =
+      ConvertOperator<TensorFlowReshapeOperator*>(
+          reshape_it->get(), OperatorType::kTensorFlowReshape);
+  if (reshape_op == nullptr) {
+    return false;
+  }
+
+  // Ignore if the reshape is uninitialized.
+  if (!OperatorReady(*model, reshape_op) || reshape_op->shape.empty()) {
+    return false;
+  }
+
+  // Need to copy to keep static if permutated.
+  const string input_name = reshape_op->inputs[0];
+  const string intermediate_name = reshape_op->outputs[0];
+  const string output_name = transpose_op->outputs[0];
+
+  // Intermediate should not be consumed by any other operators.
+  if (CountOpsWithInput(*model, intermediate_name) != 1) {
+    AddMessageF("Input %s used elsewhere", intermediate_name);
+    return false;
+  }
+
+  // Check that the intermediate is not an output array.
+  if (!IsDiscardableArray(*model, intermediate_name)) {
+    AddMessageF(
+        "Cannot reorder reshape-transpose as it would invalidate %s which is "
+        "an output array.",
+        intermediate_name);
+    return false;
+  }
+
+  // Get the arrays.
+  const auto& input_array = model->GetArray(input_name);
+  const auto& intermediate_array = model->GetArray(intermediate_name);
+  const auto& output_array = model->GetArray(output_name);
+
+  // Get the shapes of each array.
+  Shape input_shape = input_array.shape();
+  Shape intermediate_shape = intermediate_array.shape();
+  Shape output_shape = output_array.shape();
+
+  // Assign ids to non-unary indices.
+  std::vector<int> input_dims = input_shape.dims();
+  std::vector<int> intermediate_dims = intermediate_shape.dims();
+  std::vector<int> output_dims = output_shape.dims();
+
+  // If the reshape is equivalent to a transpose with fewer/more unary
+  // dimensions then it can be moved between the transpose.
+  if (!ReshapeIsEquivalentToTranspose(*model, reshape_op,
+                                      true /*allow_extra_unary_dims*/)) {
+    return false;
+  }
+
+  if (!IsDiscardableArray(*model, output_name)) {
+    // The output name of the sequence needs to stay static, so create a new
+    // array new use for the intermediate.
+    const auto new_intermediate_name =
+        AvailableArrayName(*model, transpose_op->outputs[0] + "_exchange");
+    AddMessageF("Adding new array %s to preserve output array name %s",
+                new_intermediate_name, transpose_op->outputs[0]);
+    transpose_op->inputs[0] = input_name;
+    transpose_op->outputs[0] = new_intermediate_name;
+    reshape_op->inputs[0] = new_intermediate_name;
+    reshape_op->outputs[0] = output_name;
+    model->EraseArray(intermediate_name);
+  } else {
+    // The intermediate array is now the output array.
+    for (int i = 0; i < model->operators.size(); i++) {
+      Operator* consumer = model->operators[i].get();
+      for (int j = 0; j < consumer->inputs.size(); j++) {
+        if (consumer->inputs[j] == output_name) {
+          consumer->inputs[j] = intermediate_name;
+        }
+      }
+    }
+
+    transpose_op->inputs[0] = input_name;
+    reshape_op->inputs[0] = output_name;
+  }
+
+  // If transposes constant buffer is used elsewhere, make a new copy.
+  if (CountOpsWithInput(*model, transpose_op->inputs[1]) != 1) {
+    transpose_op->inputs[1] =
+        AvailableArrayName(*model, transpose_op->inputs[1] + "_copy");
+  }
+
+  // Make the new transpose permutation.
+  const std::vector<int> new_perm =
+      ComputeNewPerm(input_dims, intermediate_dims, transpose_op->perm);
+  CHECK_EQ(input_dims.size(), new_perm.size());
+
+  auto& transpose_array = model->GetOrCreateArray(transpose_op->inputs[1]);
+  transpose_array.GetMutableBuffer<ArrayDataType::kInt32>().data = new_perm;
+  *(transpose_array.mutable_shape()->mutable_dims()) = {
+      static_cast<int>(new_perm.size())};
+  transpose_op->perm = new_perm;
+
+  // If the reshape's constant buffer is reused, create a new one.
+  if (CountOpsWithInput(*model, reshape_op->inputs[1]) != 1) {
+    reshape_op->inputs[1] =
+        AvailableArrayName(*model, reshape_op->inputs[1] + "_copy");
+  }
+
+  // We need to modify the reshape input array to target the new output size.
+  auto& reshape_array = model->GetOrCreateArray(reshape_op->inputs[1]);
+  reshape_array.GetMutableBuffer<ArrayDataType::kInt32>().data = output_dims;
+  *(reshape_array.mutable_shape()->mutable_dims()) = {
+      static_cast<int>(output_shape.dimensions_count())};
+  reshape_op->shape.clear();
+
+  AddMessageF("Swapping around operators between %s and %s", input_name,
+              output_name);
+
+  model->GetOrCreateArray(transpose_op->outputs[0]).clear_shape();
+  model->GetOrCreateArray(reshape_op->outputs[0]).clear_shape();
+
+  // Swap the order of the operators.
+  transpose_it->swap(*reshape_it);
+
+  return true;
+}
+
+}  // namespace toco
index f38203c..2a236d3 100644 (file)
@@ -60,6 +60,13 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
   string input_lhs = matmul_op->inputs[0];
   string input_rhs = transpose_op->outputs[0];
 
+  // Construct the new FullyConnectedOperator.
+  auto* fc_op = new FullyConnectedOperator;
+  fc_op->outputs = matmul_op->outputs;
+
+  // Insert the newly constructed FullyConnectedOperator.
+  model->operators.emplace(matmul_it, fc_op) + 1;
+
   // Find the op producing the array passed to this MatMul
   auto previous_op_it = model->operators.begin();
   bool found = false;
@@ -76,13 +83,6 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
   }
   Operator* previous_op = (found) ? previous_op_it->get() : nullptr;
 
-  // Construct the new FullyConnectedOperator.
-  auto* fc_op = new FullyConnectedOperator;
-  fc_op->outputs = matmul_op->outputs;
-
-  // Insert the newly constructed FullyConnectedOperator.
-  model->operators.emplace(matmul_it, fc_op) + 1;
-
   // Refresh iterator.
   matmul_it = model->operators.begin();
   for (; matmul_it != model->operators.end(); ++matmul_it) {
index 0c52f50..76e9a27 100644 (file)
@@ -74,7 +74,9 @@ void MakeGeneralGraphTransformationsSet(
   transformations->Add(new ResolveTensorFlowMatMul);
   transformations->Add(new FuseBinaryIntoPrecedingAffine);
   transformations->Add(new FuseBinaryIntoFollowingAffine);
-  transformations->Add(new ReorderActivationFunctions);
+  transformations->Add(new MergeReshapeIntoPrecedingTranspose);
+  transformations->Add(new ReorderElementwiseUnary);
+  transformations->Add(new ReorderReshapeTranspose);
   transformations->Add(new ResolveBatchNormalization);
   transformations->Add(new ResolveConstantBinaryOperator);
   transformations->Add(new ResolveConstantFill);
index 060c52e..c1eaba7 100644 (file)
@@ -1921,6 +1921,35 @@ bool IsDiscardableArray(const Model& model, const string& array_name) {
   return true;
 }
 
+bool ReshapeIsEquivalentToTranspose(const Model& model,
+                                    const TensorFlowReshapeOperator* op,
+                                    bool allow_extra_unary_dims) {
+  CHECK(!op->shape.empty());
+  CHECK(model.HasArray(op->inputs[0]));
+  CHECK(model.HasArray(op->outputs[0]));
+
+  const auto& input_array = model.GetArray(op->inputs[0]);
+  const auto& output_array = model.GetArray(op->outputs[0]);
+
+  CHECK(input_array.has_shape());
+  CHECK(output_array.has_shape());
+
+  std::vector<int> in_shape = input_array.shape().dims();
+  std::vector<int> out_shape = output_array.shape().dims();
+
+  // If the reshape changes the number of dimensions so it cannot be interpreted
+  // as a transpose.
+  if (!allow_extra_unary_dims && in_shape.size() != out_shape.size()) {
+    return false;
+  }
+
+  in_shape.erase(std::remove(in_shape.begin(), in_shape.end(), 1),
+                 in_shape.end());
+  out_shape.erase(std::remove(out_shape.begin(), out_shape.end(), 1),
+                  out_shape.end());
+  return in_shape == out_shape;
+}
+
 void CheckFinalDataTypesSatisfied(const Model& model) {
   for (const auto& array_entry : model.GetArrayMap()) {
     const auto& array = *array_entry.second;
index d3b7224..259ee7f 100644 (file)
@@ -169,10 +169,23 @@ void GetQuantizationParamsFromMinMax(const MinMax& minmax,
       ::tflite::ChooseQuantizationParams<Integer>(rmin, rmax);
 }
 
+template <typename T>
+T ConvertOperator(Operator* o, OperatorType type) {
+  if (o != nullptr && o->type == type) {
+    return static_cast<T>(o);
+  }
+
+  return nullptr;
+}
+
 void CheckIsReadyForQuantization(const Model& model);
 void UseDefaultMinMaxRangeValues(Model* model, double default_ranges_min,
                                  double default_ranges_max);
 
+bool ReshapeIsEquivalentToTranspose(const Model& model,
+                                    const TensorFlowReshapeOperator* op,
+                                    bool allow_extra_unary_dims);
+
 inline int Offset(const Shape& shape, const std::vector<int>& indices) {
   DCHECK_EQ(shape.dimensions_count(), indices.size());
   const int dims_count = shape.dimensions_count();