"graph_transformations/identify_relu1.cc",
"graph_transformations/lstm_utils.cc",
"graph_transformations/make_initial_dequantize_operator.cc",
+ "graph_transformations/merge_reshape_into_preceding_transpose.cc",
"graph_transformations/propagate_activation_function_into_constants.cc",
"graph_transformations/propagate_array_data_types.cc",
"graph_transformations/propagate_fixed_sizes.cc",
"graph_transformations/remove_trivial_reshape.cc",
"graph_transformations/remove_trivial_slice.cc",
"graph_transformations/remove_unused_op.cc",
- "graph_transformations/reorder_activation_functions.cc",
+ "graph_transformations/reorder_elementwise_unary.cc",
+ "graph_transformations/reorder_reshape_transpose.cc",
"graph_transformations/resolve_batch_normalization.cc",
"graph_transformations/resolve_batch_to_space_nd_attributes.cc",
"graph_transformations/resolve_constant_binary.cc",
DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell)
DECLARE_GRAPH_TRANSFORMATION(SplitLstmCellInputs)
DECLARE_GRAPH_TRANSFORMATION(MergeLstmCellInputs)
+DECLARE_GRAPH_TRANSFORMATION(MergeReshapeIntoPrecedingTranspose)
DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1)
DECLARE_GRAPH_TRANSFORMATION(IdentifyPRelu)
DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv)
DECLARE_GRAPH_TRANSFORMATION(CreateIm2colArrays)
DECLARE_GRAPH_TRANSFORMATION(DropIm2colArrays)
DECLARE_GRAPH_TRANSFORMATION(ReadFakeQuantMinMax)
-DECLARE_GRAPH_TRANSFORMATION(ReorderActivationFunctions)
+DECLARE_GRAPH_TRANSFORMATION(ReorderElementwiseUnary)
+DECLARE_GRAPH_TRANSFORMATION(ReorderReshapeTranspose)
DECLARE_GRAPH_TRANSFORMATION(ResolveReorderAxes)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowConcat)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul)
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+// Returns true once the transformation can safely run on `op`: all three
+// arrays exist, input/output shapes are known, and the second input (the
+// constant perm/shape argument) has a constant buffer. Used to defer the
+// rewrite until shape/constant propagation has completed.
+bool OperatorReady(const Model& model, const Operator* op) {
+  if (!model.HasArray(op->inputs[0]) || !model.HasArray(op->inputs[1]) ||
+      !model.HasArray(op->outputs[0])) {
+    // Arrays are missing.
+    return false;
+  }
+
+  if (!model.GetArray(op->inputs[0]).has_shape() ||
+      !model.GetArray(op->outputs[0]).has_shape()) {
+    // Input and output needs the shape.
+    return false;
+  }
+
+  if (!model.GetArray(op->inputs[1]).buffer) {
+    // Buffer needs to be a constant.
+    return false;
+  }
+
+  return true;
+}
+
+// Computes the permutation that makes the given reshape expressible as a
+// transpose. Only valid when the reshape merely moves unary (size-1)
+// dimensions around — i.e. when ReshapeIsEquivalentToTranspose() holds —
+// since the non-unary dimensions must keep their relative order.
+std::vector<int32> ReshapeToTranspose(const Model& model,
+                                      const TensorFlowReshapeOperator* op) {
+  CHECK(!op->shape.empty());
+  CHECK(model.HasArray(op->inputs[0]));
+  CHECK(model.HasArray(op->outputs[0]));
+
+  const auto& input_array = model.GetArray(op->inputs[0]);
+  const auto& output_array = model.GetArray(op->outputs[0]);
+
+  CHECK(input_array.has_shape());
+  CHECK(output_array.has_shape());
+
+  std::vector<int> in_shape = input_array.shape().dims();
+  std::vector<int> out_shape = output_array.shape().dims();
+
+  std::vector<int> one_indices;
+  std::vector<int> not_one_indices;
+
+  // Separate the input axes into unary (size-1) and non-unary indices.
+  for (int i = 0; i < in_shape.size(); i++) {
+    if (in_shape[i] == 1) {
+      one_indices.push_back(i);
+    } else {
+      not_one_indices.push_back(i);
+    }
+  }
+
+  // Drain both index lists in output order: each size-1 output dim takes the
+  // next unused unary input axis, every other dim takes the next non-unary
+  // input axis (their relative order is preserved by assumption).
+  std::vector<int> perm;
+  perm.reserve(in_shape.size());
+  int one_index = 0;
+  int not_one_index = 0;
+  for (const auto val : out_shape) {
+    if (val == 1) {
+      perm.push_back(one_indices[one_index]);
+      one_index++;
+    } else {
+      perm.push_back(not_one_indices[not_one_index]);
+      not_one_index++;
+    }
+  }
+
+  return perm;
+}
+
+} // namespace
+
+// When a transpose is fed into a reshape, the two operators can be merged if
+// the reshape does not affect memory ordering and does not change the number
+// of dimensions. This only occurs when the reshape merely shifts unary
+// (size-1) dimensions around.
+bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
+                                             std::size_t op_index) {
+  auto it = model->operators.begin() + op_index;
+  auto* reshape_op = ConvertOperator<TensorFlowReshapeOperator*>(
+      it->get(), OperatorType::kTensorFlowReshape);
+
+  if (reshape_op == nullptr) {
+    return false;
+  }
+
+  // Wait until shapes have propagated and the new-shape input is constant.
+  if (!OperatorReady(*model, reshape_op) || reshape_op->shape.empty()) {
+    return false;
+  }
+
+  const string intermediate_name = reshape_op->inputs[0];
+  const string output_name = reshape_op->outputs[0];
+
+  // Guarantee the input is only consumed by the reshape.
+  if (CountOpsWithInput(*model, intermediate_name) != 1) {
+    return false;
+  }
+
+  // Check for the parent operator.
+  const auto& transpose_it = FindOpWithOutput(*model, intermediate_name);
+  if (transpose_it == model->operators.end()) {
+    return false;
+  }
+
+  // Find the parent operator and guarantee it is a transpose.
+  TransposeOperator* transpose_op = ConvertOperator<TransposeOperator*>(
+      transpose_it->get(), OperatorType::kTranspose);
+
+  if (transpose_op == nullptr) {
+    return false;
+  }
+
+  if (!OperatorReady(*model, transpose_op) || transpose_op->perm.empty()) {
+    return false;
+  }
+
+  // The reshape must be a pure permutation of the same rank (no extra unary
+  // dimensions), so the merged result is still a single transpose.
+  if (!ReshapeIsEquivalentToTranspose(*model, reshape_op,
+                                      false /*allow_extra_unary_dimensions*/)) {
+    return false;
+  }
+
+  // Check that the intermediate is not an output array.
+  if (!IsDiscardableArray(*model, intermediate_name)) {
+    AddMessageF(
+        "Cannot fuse %s and %s as it would invalidate the transpose "
+        "output array.",
+        LogName(*transpose_op), LogName(*reshape_op));
+    return false;
+  }
+
+  AddMessageF("Merging operations %s and %s", LogName(*transpose_op),
+              LogName(*reshape_op));
+
+  auto merged_perm = ReshapeToTranspose(*model, reshape_op);
+
+  // Combine the permutations: apply the transpose's perm first, then the
+  // reshape-as-transpose perm on top of it.
+  const auto& transpose_perm = transpose_op->perm;
+  for (int i = 0; i < merged_perm.size(); i++) {
+    merged_perm[i] = transpose_perm[merged_perm[i]];
+  }
+
+  // Remove the reshape as passthrough operation.
+  if (!RemoveTrivialPassthroughOp(this, model, op_index)) {
+    return false;
+  }
+
+  // Update transpose_op's constant buffer to contain the new permutation.
+  model->GetArray(transpose_op->inputs[1])
+      .GetMutableBuffer<ArrayDataType::kInt32>()
+      .data = merged_perm;
+  transpose_op->perm = merged_perm;
+
+  // transpose_op's output shape will likely have changed; clear it so that
+  // shape propagation recomputes it.
+  model->GetArray(transpose_op->outputs[0]).clear_shape();
+
+  return true;
+}
+
+} // namespace toco
+++ /dev/null
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
-#include "tensorflow/contrib/lite/toco/model.h"
-#include "tensorflow/contrib/lite/toco/runtime/types.h"
-#include "tensorflow/contrib/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
-
-namespace toco {
-
-bool ReorderActivationFunctions::Run(Model* model, std::size_t op_index) {
- const auto ac_it = model->operators.begin() + op_index;
- std::unique_ptr<Operator>& ac_op = *ac_it;
- DCHECK(ac_op);
-
- if (ac_op->type != OperatorType::kRelu6 &&
- ac_op->type != OperatorType::kRelu1 &&
- ac_op->type != OperatorType::kRelu) {
- return false;
- }
-
- auto exchange_it = FindOpWithOutput(*model, ac_op->inputs[0]);
- if (exchange_it == model->operators.end()) return false;
- // Find the op producing the array passed to this activation function
- std::unique_ptr<Operator>& exchange_op = *exchange_it;
- DCHECK(exchange_op);
-
- // Allow activation functions to move up over any operator that does not
- // change the values.
- switch (exchange_op->type) {
- case OperatorType::kExpandDims:
- case OperatorType::kSqueeze:
- case OperatorType::kTensorFlowReshape:
- case OperatorType::kTranspose:
- break;
- default:
- return false;
- }
-
- DCHECK_EQ(exchange_op->outputs[0], ac_op->inputs[0]);
- const auto exchange_op_input = exchange_op->inputs[0];
- const auto intermediate_array = exchange_op->outputs[0];
- const auto ac_op_output = ac_op->outputs[0];
-
- int count_ops_consuming_output =
- CountOpsWithInput(*model, intermediate_array);
- DCHECK_GE(count_ops_consuming_output, 1);
- if (count_ops_consuming_output > 1) {
- AddMessageF(
- "Not exchanging activation function with %s because it is consumed by "
- "more than 1 other operator",
- LogName(*exchange_op));
- return false;
- }
-
- // If the ac_op was originally producing an output_array we can't trivially
- // reorder as otherwise the output array name would change and break
- // downstream assumptions. To work around that we perform some renaming below
- // in that case at the cost of a bit more confusing array names in this rare
- // case.
- bool is_ac_op_output =
- std::find(model->flags.output_arrays().begin(),
- model->flags.output_arrays().end(),
- ac_op_output) != model->flags.output_arrays().end();
- if (is_ac_op_output) {
- // To preserve the output array name of the activation function we need to
- // create a temporary to use to pass between ac->ex.
- //
- // Original:
- // (a) -> EX -> (b) -> AC -> (c)
- // Now:
- // (a) -> AC -> (c') -> EX -> (c)
- AddMessageF(
- "Exchanging activation function %s with %s but renaming to preserve "
- "output array %s",
- LogName(*ac_op), LogName(*exchange_op), ac_op->outputs[0]);
-
- auto renamed_ac_op_output =
- AvailableArrayName(*model, ac_op_output + "_exchange");
- ac_op->inputs[0] = exchange_op_input;
- ac_op->outputs[0] = renamed_ac_op_output;
- model->EraseArray(exchange_op->outputs[0]);
- exchange_op->inputs[0] = renamed_ac_op_output;
- exchange_op->outputs[0] = ac_op_output;
- } else {
- // Simply swap the order and update consumers to use the exchange_op output
- // array (b).
- //
- // Original:
- // (a) -> EX -> (b) -> AC -> (c)
- // Now:
- // (a) -> AC -> (c) -> EX -> (b)
- AddMessageF("Exchanging activation function %s with %s", LogName(*ac_op),
- LogName(*exchange_op));
-
- Operator* consumer = GetFirstOpWithInput(*model, ac_op_output);
- while (consumer) {
- for (int i = 0; i < consumer->inputs.size(); ++i) {
- if (consumer->inputs[i] == ac_op_output) {
- consumer->inputs[i] = intermediate_array;
- }
- }
- consumer = GetFirstOpWithInput(*model, ac_op_output);
- }
- ac_op->inputs[0] = exchange_op_input;
- exchange_op->inputs[0] = ac_op_output;
- }
-
- // Clear shapes; this will allow shape propagation to fix the sizes for us.
- model->GetOrCreateArray(ac_op->outputs[0]).clear_shape();
- model->GetOrCreateArray(exchange_op->outputs[0]).clear_shape();
-
- // Finally, reorder operators. Note that this only works when there are no
- // other direct descendents of the exchange_op.
- ac_op.swap(exchange_op);
-
- return true;
-}
-
-} // namespace toco
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+// Returns true for unary operators that transform each element's value
+// independently of its position, and therefore commute with pure
+// data-movement operators.
+bool IsElementwiseOperator(OperatorType optype) {
+  switch (optype) {
+    case OperatorType::kCast:
+    case OperatorType::kExp:
+    case OperatorType::kFloor:
+    case OperatorType::kNeg:
+    case OperatorType::kRelu:
+    case OperatorType::kRelu1:
+    case OperatorType::kRelu6:
+    case OperatorType::kTanh:
+    case OperatorType::kTensorFlowSqrt:
+    case OperatorType::kTensorFlowSquare:
+      return true;
+    default:
+      return false;
+  }
+}
+
+// Returns true for operators that only move or relabel elements without
+// changing their values, so an elementwise op can be swapped across them.
+bool IsMoveOperator(OperatorType optype) {
+  switch (optype) {
+    case OperatorType::kDepthToSpace:
+    case OperatorType::kExpandDims:
+    case OperatorType::kSpaceToDepth:
+    case OperatorType::kSqueeze:
+    case OperatorType::kTensorFlowReshape:
+    case OperatorType::kTranspose:
+      return true;
+    default:
+      return false;
+  }
+}
+
+} // namespace
+
+// Swap elementwise operators such that all value operators occur before all
+// element move operators, e.g. negation then transpose.
+bool ReorderElementwiseUnary::Run(Model* model, std::size_t op_index) {
+  const auto element_op_it = model->operators.begin() + op_index;
+  std::unique_ptr<Operator>& element_op = *element_op_it;
+  if (!IsElementwiseOperator(element_op->type)) {
+    return false;
+  }
+
+  const string intermediate_name = element_op->inputs[0];
+  auto it = FindOpWithOutput(*model, intermediate_name);
+  if (it == model->operators.end()) {
+    AddMessageF("No preceding operator");
+    return false;
+  }
+
+  std::unique_ptr<Operator>& move_op = *it;
+  if (!IsMoveOperator(move_op->type)) {
+    AddMessageF("Preceding operator is not a move operator");
+    return false;
+  }
+
+  // The intermediate may only feed the elementwise op, otherwise the swap
+  // would change what other consumers observe.
+  if (CountOpsWithInput(*model, intermediate_name) != 1) {
+    AddMessageF("Input %s used elsewhere", intermediate_name);
+    return false;
+  }
+
+  // Check that the intermediate is discardable.
+  if (!IsDiscardableArray(*model, intermediate_name)) {
+    AddMessageF(
+        "Cannot swap elementwise as it would invalidate %s which is "
+        "an output array.",
+        intermediate_name);
+    return false;
+  }
+
+  // op->inputs may change so we need to keep a value by copy.
+  const string input_name = move_op->inputs[0];
+  const string output_name = element_op->outputs[0];
+
+  AddMessageF("Swapping around operators with %s and %s", LogName(*element_op),
+              LogName(*move_op));
+
+  // If the output array is an exit node for the graph then we need to retain
+  // the name as an output node. This makes the naming scheme a little
+  // confusing but is required in this rare case.
+  if (!IsDiscardableArray(*model, output_name)) {
+    // The output name of the sequence needs to stay static, so create a new
+    // array to use for the intermediate.
+    const auto new_intermediate_name =
+        AvailableArrayName(*model, element_op->outputs[0] + "_reorder");
+    AddMessageF("Adding new array %s to preserve output array name %s",
+                new_intermediate_name, output_name);
+
+    element_op->inputs[0] = input_name;
+    element_op->outputs[0] = new_intermediate_name;
+    model->EraseArray(intermediate_name);
+    move_op->inputs[0] = new_intermediate_name;
+    move_op->outputs[0] = output_name;
+  } else {
+    // The intermediate array is now the output array; rewrite every consumer
+    // of the old output name to read the intermediate instead.
+    for (int i = 0; i < model->operators.size(); i++) {
+      Operator* consumer = model->operators[i].get();
+      for (int j = 0; j < consumer->inputs.size(); j++) {
+        if (consumer->inputs[j] == output_name) {
+          consumer->inputs[j] = intermediate_name;
+        }
+      }
+    }
+
+    element_op->inputs[0] = input_name;
+    move_op->inputs[0] = output_name;
+  }
+
+  // Reset both arrays as shape, type, min/max, etc can all change because of
+  // the position swap.
+  model->EraseArray(element_op->outputs[0]);
+  model->EraseArray(move_op->outputs[0]);
+
+  // Reconstruct empty arrays; later propagation passes refill them.
+  model->GetOrCreateArray(element_op->outputs[0]);
+  model->GetOrCreateArray(move_op->outputs[0]);
+
+  // Swap the order of the operators.
+  element_op.swap(move_op);
+
+  return true;
+}
+
+} // namespace toco
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+// Returns true once the transformation can safely run on `op`: all three
+// arrays exist, input/output shapes are known, and the second input (the
+// constant perm/shape argument) has a constant buffer.
+bool OperatorReady(const Model& model, const Operator* op) {
+  if (!model.HasArray(op->inputs[0]) || !model.HasArray(op->inputs[1]) ||
+      !model.HasArray(op->outputs[0])) {
+    // Arrays are missing.
+    return false;
+  }
+
+  if (!model.GetArray(op->inputs[0]).has_shape() ||
+      !model.GetArray(op->outputs[0]).has_shape()) {
+    // Input and output needs the shape.
+    return false;
+  }
+
+  if (!model.GetArray(op->inputs[1]).buffer) {
+    // Buffer needs to be a constant.
+    return false;
+  }
+
+  return true;
+}
+
+// Utility function to remove every occurrence of `value` from `vec`
+// (erase-remove idiom).
+// NOTE(review): appears unused in this file — candidate for removal.
+void Filter(std::vector<int>* vec, int value) {
+  vec->erase(std::remove(vec->begin(), vec->end(), value), vec->end());
+}
+
+// Computes a new permutation used to swap a reshape-transpose to a
+// transpose-reshape. The original `perm` operates on the intermediate
+// (post-reshape) shape; the returned permutation operates on the input
+// shape instead. Assumes the reshape only adds/removes/moves unary
+// dimensions (i.e. ReshapeIsEquivalentToTranspose with extra unary dims
+// allowed), so non-unary dims keep their relative order across the reshape.
+std::vector<int> ComputeNewPerm(std::vector<int> input_dims,
+                                std::vector<int> intermediate_dims,
+                                std::vector<int> perm) {
+  // These are the major (non-unary) axes of the input, in order.
+  std::vector<int> input_indices;
+  for (int i = 0; i < input_dims.size(); i++) {
+    if (input_dims[i] != 1) {
+      input_indices.push_back(i);
+    }
+  }
+
+  // This maps which indices of the input produced the intermediate indices
+  // for non-unary dimensions: the k-th non-unary intermediate dim comes from
+  // the k-th non-unary input dim, since the reshape preserves their order.
+  std::unordered_map<int, int> intermediate_to_input_indices_map;
+  for (int i = 0; i < intermediate_dims.size(); i++) {
+    if (intermediate_dims[i] != 1) {
+      intermediate_to_input_indices_map[i] =
+          input_indices[intermediate_to_input_indices_map.size()];
+    }
+  }
+
+  // Translate the transpose permutation to a new permutation starting with
+  // the major indices.
+  std::vector<int> new_perm;
+  new_perm.reserve(input_dims.size());
+  for (int i = 0; i < perm.size(); i++) {
+    if (intermediate_dims[perm[i]] == 1) continue;
+
+    new_perm.push_back(intermediate_to_input_indices_map[perm[i]]);
+  }
+
+  // Fill the rest of the permutation in with the input's unary dimensions,
+  // which end up trailing after the transpose.
+  for (int index = 0; index < input_dims.size(); index++) {
+    if (input_dims[index] == 1) {
+      new_perm.push_back(index);
+    }
+  }
+
+  CHECK_EQ(new_perm.size(), input_dims.size());
+  return new_perm;
+}
+
+} // namespace
+
+// Swaps reshape-transpose to transpose-reshape whenever possible. This is
+// possible when the reshape does not affect memory ordering (it only
+// adds/removes/moves unary dimensions).
+bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) {
+  auto transpose_it = model->operators.begin() + op_index;
+
+  TransposeOperator* transpose_op = ConvertOperator<TransposeOperator*>(
+      transpose_it->get(), OperatorType::kTranspose);
+
+  if (transpose_op == nullptr) {
+    return false;
+  }
+
+  if (!OperatorReady(*model, transpose_op) || transpose_op->perm.empty()) {
+    // Wait for values to propagate.
+    return false;
+  }
+
+  // Find the operator that produces the transpose op.
+  auto reshape_it = FindOpWithOutput(*model, transpose_op->inputs[0]);
+  if (reshape_it == model->operators.end()) {
+    return false;
+  }
+
+  TensorFlowReshapeOperator* reshape_op =
+      ConvertOperator<TensorFlowReshapeOperator*>(
+          reshape_it->get(), OperatorType::kTensorFlowReshape);
+  if (reshape_op == nullptr) {
+    return false;
+  }
+
+  // Ignore if the reshape is uninitialized.
+  if (!OperatorReady(*model, reshape_op) || reshape_op->shape.empty()) {
+    return false;
+  }
+
+  // Need to copy: op->inputs/outputs entries are rewritten (permuted) below.
+  const string input_name = reshape_op->inputs[0];
+  const string intermediate_name = reshape_op->outputs[0];
+  const string output_name = transpose_op->outputs[0];
+
+  // Intermediate should not be consumed by any other operators.
+  if (CountOpsWithInput(*model, intermediate_name) != 1) {
+    AddMessageF("Input %s used elsewhere", intermediate_name);
+    return false;
+  }
+
+  // Check that the intermediate is not an output array.
+  if (!IsDiscardableArray(*model, intermediate_name)) {
+    AddMessageF(
+        "Cannot reorder reshape-transpose as it would invalidate %s which is "
+        "an output array.",
+        intermediate_name);
+    return false;
+  }
+
+  // Get the arrays.
+  const auto& input_array = model->GetArray(input_name);
+  const auto& intermediate_array = model->GetArray(intermediate_name);
+  const auto& output_array = model->GetArray(output_name);
+
+  // Get the shapes of each array.
+  Shape input_shape = input_array.shape();
+  Shape intermediate_shape = intermediate_array.shape();
+  Shape output_shape = output_array.shape();
+
+  // Copy the dims before any graph surgery below invalidates the arrays.
+  std::vector<int> input_dims = input_shape.dims();
+  std::vector<int> intermediate_dims = intermediate_shape.dims();
+  std::vector<int> output_dims = output_shape.dims();
+
+  // If the reshape is equivalent to a transpose with fewer/more unary
+  // dimensions then it can be moved between the transpose.
+  if (!ReshapeIsEquivalentToTranspose(*model, reshape_op,
+                                      true /*allow_extra_unary_dims*/)) {
+    return false;
+  }
+
+  if (!IsDiscardableArray(*model, output_name)) {
+    // The output name of the sequence needs to stay static, so create a new
+    // array to use for the intermediate.
+    const auto new_intermediate_name =
+        AvailableArrayName(*model, transpose_op->outputs[0] + "_exchange");
+    AddMessageF("Adding new array %s to preserve output array name %s",
+                new_intermediate_name, transpose_op->outputs[0]);
+    transpose_op->inputs[0] = input_name;
+    transpose_op->outputs[0] = new_intermediate_name;
+    reshape_op->inputs[0] = new_intermediate_name;
+    reshape_op->outputs[0] = output_name;
+    model->EraseArray(intermediate_name);
+  } else {
+    // The intermediate array is now the output array; rewrite every consumer
+    // of the old output name to read the intermediate instead.
+    for (int i = 0; i < model->operators.size(); i++) {
+      Operator* consumer = model->operators[i].get();
+      for (int j = 0; j < consumer->inputs.size(); j++) {
+        if (consumer->inputs[j] == output_name) {
+          consumer->inputs[j] = intermediate_name;
+        }
+      }
+    }
+
+    transpose_op->inputs[0] = input_name;
+    reshape_op->inputs[0] = output_name;
+  }
+
+  // If the transpose's constant perm buffer is used elsewhere, switch to a
+  // fresh name; GetOrCreateArray below materializes the new array.
+  if (CountOpsWithInput(*model, transpose_op->inputs[1]) != 1) {
+    transpose_op->inputs[1] =
+        AvailableArrayName(*model, transpose_op->inputs[1] + "_copy");
+  }
+
+  // Make the new transpose permutation.
+  const std::vector<int> new_perm =
+      ComputeNewPerm(input_dims, intermediate_dims, transpose_op->perm);
+  CHECK_EQ(input_dims.size(), new_perm.size());
+
+  auto& transpose_array = model->GetOrCreateArray(transpose_op->inputs[1]);
+  transpose_array.GetMutableBuffer<ArrayDataType::kInt32>().data = new_perm;
+  *(transpose_array.mutable_shape()->mutable_dims()) = {
+      static_cast<int>(new_perm.size())};
+  transpose_op->perm = new_perm;
+
+  // If the reshape's constant buffer is reused, create a new one.
+  if (CountOpsWithInput(*model, reshape_op->inputs[1]) != 1) {
+    reshape_op->inputs[1] =
+        AvailableArrayName(*model, reshape_op->inputs[1] + "_copy");
+  }
+
+  // We need to modify the reshape input array to target the new output size.
+  auto& reshape_array = model->GetOrCreateArray(reshape_op->inputs[1]);
+  reshape_array.GetMutableBuffer<ArrayDataType::kInt32>().data = output_dims;
+  *(reshape_array.mutable_shape()->mutable_dims()) = {
+      static_cast<int>(output_shape.dimensions_count())};
+  reshape_op->shape.clear();
+
+  AddMessageF("Swapping around operators between %s and %s", input_name,
+              output_name);
+
+  // Both output shapes changed; clear them so shape propagation recomputes.
+  model->GetOrCreateArray(transpose_op->outputs[0]).clear_shape();
+  model->GetOrCreateArray(reshape_op->outputs[0]).clear_shape();
+
+  // Swap the order of the operators.
+  transpose_it->swap(*reshape_it);
+
+  return true;
+}
+} // namespace toco
string input_lhs = matmul_op->inputs[0];
string input_rhs = transpose_op->outputs[0];
+  // Construct the new FullyConnectedOperator.
+  auto* fc_op = new FullyConnectedOperator;
+  fc_op->outputs = matmul_op->outputs;
+
+  // Insert the newly constructed FullyConnectedOperator. The returned
+  // iterator is not needed here; matmul_it is refreshed by a scan below.
+  model->operators.emplace(matmul_it, fc_op);
+
// Find the op producing the array passed to this MatMul
auto previous_op_it = model->operators.begin();
bool found = false;
}
Operator* previous_op = (found) ? previous_op_it->get() : nullptr;
- // Construct the new FullyConnectedOperator.
- auto* fc_op = new FullyConnectedOperator;
- fc_op->outputs = matmul_op->outputs;
-
- // Insert the newly constructed FullyConnectedOperator.
- model->operators.emplace(matmul_it, fc_op) + 1;
-
// Refresh iterator.
matmul_it = model->operators.begin();
for (; matmul_it != model->operators.end(); ++matmul_it) {
transformations->Add(new ResolveTensorFlowMatMul);
transformations->Add(new FuseBinaryIntoPrecedingAffine);
transformations->Add(new FuseBinaryIntoFollowingAffine);
- transformations->Add(new ReorderActivationFunctions);
+ transformations->Add(new MergeReshapeIntoPrecedingTranspose);
+ transformations->Add(new ReorderElementwiseUnary);
+ transformations->Add(new ReorderReshapeTranspose);
transformations->Add(new ResolveBatchNormalization);
transformations->Add(new ResolveConstantBinaryOperator);
transformations->Add(new ResolveConstantFill);
return true;
}
+// Returns whether the reshape is equivalent to a transpose that only moves
+// unary (size-1) dimensions, i.e. it does not change memory ordering of the
+// non-unary dims. When `allow_extra_unary_dims` is true the input and output
+// ranks may differ (extra or missing size-1 dims); otherwise they must match.
+// NOTE(review): uses std::remove — confirm <algorithm> is included by this
+// translation unit.
+bool ReshapeIsEquivalentToTranspose(const Model& model,
+                                    const TensorFlowReshapeOperator* op,
+                                    bool allow_extra_unary_dims) {
+  CHECK(!op->shape.empty());
+  CHECK(model.HasArray(op->inputs[0]));
+  CHECK(model.HasArray(op->outputs[0]));
+
+  const auto& input_array = model.GetArray(op->inputs[0]);
+  const auto& output_array = model.GetArray(op->outputs[0]);
+
+  CHECK(input_array.has_shape());
+  CHECK(output_array.has_shape());
+
+  std::vector<int> in_shape = input_array.shape().dims();
+  std::vector<int> out_shape = output_array.shape().dims();
+
+  // If the reshape changes the number of dimensions then it cannot be
+  // interpreted as a rank-preserving transpose.
+  if (!allow_extra_unary_dims && in_shape.size() != out_shape.size()) {
+    return false;
+  }
+
+  // Strip all unary dims; the reshape preserves memory order iff the
+  // remaining non-unary dims are identical and in the same order.
+  in_shape.erase(std::remove(in_shape.begin(), in_shape.end(), 1),
+                 in_shape.end());
+  out_shape.erase(std::remove(out_shape.begin(), out_shape.end(), 1),
+                  out_shape.end());
+  return in_shape == out_shape;
+}
+
void CheckFinalDataTypesSatisfied(const Model& model) {
for (const auto& array_entry : model.GetArrayMap()) {
const auto& array = *array_entry.second;
::tflite::ChooseQuantizationParams<Integer>(rmin, rmax);
}
+// Downcasts `o` to the operator pointer type `T` when its runtime `type` tag
+// matches `type`; returns nullptr for null input or a tag mismatch. Callers
+// use the nullptr result as a cheap "is this op the kind I want" test.
+template <typename T>
+T ConvertOperator(Operator* o, OperatorType type) {
+  if (o != nullptr && o->type == type) {
+    return static_cast<T>(o);
+  }
+
+  return nullptr;
+}
+
void CheckIsReadyForQuantization(const Model& model);
void UseDefaultMinMaxRangeValues(Model* model, double default_ranges_min,
double default_ranges_max);
+bool ReshapeIsEquivalentToTranspose(const Model& model,
+ const TensorFlowReshapeOperator* op,
+ bool allow_extra_unary_dims);
+
inline int Offset(const Shape& shape, const std::vector<int>& indices) {
DCHECK_EQ(shape.dimensions_count(), indices.size());
const int dims_count = shape.dimensions_count();