"graph_transformations/resolve_constant_gather.cc",
"graph_transformations/resolve_constant_random_uniform.cc",
"graph_transformations/resolve_constant_range.cc",
+ "graph_transformations/resolve_constant_reshape.cc",
"graph_transformations/resolve_constant_shape_or_rank.cc",
"graph_transformations/resolve_constant_stack.cc",
"graph_transformations/resolve_constant_strided_slice.cc",
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowTile)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFakeQuant)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantReshape)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTranspose)
DECLARE_GRAPH_TRANSFORMATION(DropFakeQuant)
DECLARE_GRAPH_TRANSFORMATION(UnfuseActivationFunctions)
ShapesAgreeUpToExtending(input_array.shape(), output_array.shape())) {
transformation->AddMessageF(
"%s is trivial because its input and output shapes are equal up to "
- "extending "
- "by 1's, and we are told to aggressively discard such Reshape ops.",
+ "extending by 1's, and we are told to aggressively discard such "
+ "Reshape ops.",
LogName(op));
return true;
}
}
if (!IsReshapeTrivial(*model, *reshape_op, this)) {
+ AddMessageF("%s is not trivial", LogName(*reshape_op));
return false;
}
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+// Resolves a constant reshape operation by copying the buffer.
+bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
+ auto it = model->operators.begin() + op_index;
+ const auto* base_op = it->get();
+ if (base_op->type != OperatorType::kTensorFlowReshape) {
+ return false;
+ }
+ const auto* op = static_cast<const TensorFlowReshapeOperator*>(base_op);
+
+ CHECK_EQ(op->inputs.size(), 2);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ // We require constant inputs.
+ if (!IsConstantParameterArray(*model, op->inputs[0]) ||
+ !IsConstantParameterArray(*model, op->inputs[1])) {
+ return false;
+ }
+
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.data_type == ArrayDataType::kNone) {
+ // Yield until the output type has been set by PropagateArrayDataTypes.
+ return false;
+ }
+ if (!output_array.has_shape()) {
+ // Yield until the output shape has been set by PropagateFixedShapes.
+ return false;
+ }
+
+ const Array& input_array = model->GetArray(op->inputs[0]);
+ if (!ShapesAgreeUpToExtending(input_array.shape(), output_array.shape())) {
+ AddMessageF("Constant reshape is non-trivial (%s -> %s)",
+ ShapeToString(input_array.shape()),
+ ShapeToString(output_array.shape()));
+ return false;
+ }
+
+ CHECK(!output_array.buffer);
+ switch (input_array.data_type) {
+ case ArrayDataType::kBool:
+ CopyArrayBuffer<ArrayDataType::kBool>(input_array, &output_array);
+ break;
+ case ArrayDataType::kFloat:
+ CopyArrayBuffer<ArrayDataType::kFloat>(input_array, &output_array);
+ break;
+ case ArrayDataType::kInt8:
+ CopyArrayBuffer<ArrayDataType::kInt8>(input_array, &output_array);
+ break;
+ case ArrayDataType::kUint8:
+ CopyArrayBuffer<ArrayDataType::kUint8>(input_array, &output_array);
+ break;
+ case ArrayDataType::kInt16:
+ CopyArrayBuffer<ArrayDataType::kInt16>(input_array, &output_array);
+ break;
+ case ArrayDataType::kUint16:
+ CopyArrayBuffer<ArrayDataType::kUint16>(input_array, &output_array);
+ break;
+ case ArrayDataType::kInt32:
+ CopyArrayBuffer<ArrayDataType::kInt32>(input_array, &output_array);
+ break;
+ case ArrayDataType::kUint32:
+ CopyArrayBuffer<ArrayDataType::kUint32>(input_array, &output_array);
+ break;
+ case ArrayDataType::kInt64:
+ CopyArrayBuffer<ArrayDataType::kInt64>(input_array, &output_array);
+ break;
+ case ArrayDataType::kUint64:
+ CopyArrayBuffer<ArrayDataType::kUint64>(input_array, &output_array);
+ break;
+ case ArrayDataType::kString:
+ CopyArrayBuffer<ArrayDataType::kString>(input_array, &output_array);
+ break;
+ default:
+ LOG(FATAL) << "Unsupported data type: "
+ << ArrayDataTypeName(input_array.data_type);
+ return false;
+ }
+
+ AddMessageF("Resolving constant reshape of %s", LogName(*op));
+
+ if (input_array.minmax) {
+ output_array.GetOrCreateMinMax() = input_array.GetMinMax();
+ }
+ if (input_array.quantization_params) {
+ output_array.GetOrCreateQuantizationParams() =
+ input_array.GetQuantizationParams();
+ }
+
+ // Erase input arrays if no longer used.
+ for (const auto& input : op->inputs) {
+ if (IsDiscardableArray(*model, input) &&
+ CountOpsWithInput(*model, input) == 1) {
+ model->EraseArray(input);
+ }
+ }
+
+ // Erase the operator.
+ model->operators.erase(it);
+ return true;
+}
+
+} // namespace toco
if (input_array.minmax) {
output_array.GetOrCreateMinMax() = input_array.GetMinMax();
}
+ if (input_array.quantization_params) {
+ output_array.GetOrCreateQuantizationParams() =
+ input_array.GetQuantizationParams();
+ }
if (op->perm.empty()) {
// Yield until perm has been populated by ResolveTransposeAttributes.
break;
}
+ AddMessageF("Resolving constant transpose of %s", LogName(*op));
+
// Erase input arrays if no longer used.
for (const auto& input : op->inputs) {
if (IsDiscardableArray(*model, input) &&
transformations->Add(new ResolveConstantGather);
transformations->Add(new ResolveConstantRandomUniform);
transformations->Add(new ResolveConstantRange);
+ transformations->Add(new ResolveConstantReshape);
transformations->Add(new ResolveConstantStack);
transformations->Add(new ResolveConstantStridedSlice);
transformations->Add(new ResolveConstantTranspose);
model->operators.emplace_back(copy_op);
}
-namespace {
-template <ArrayDataType A>
-void CopyArrayBuffer(const Array& source_array, Array* target_array) {
- if (source_array.buffer) {
- const auto& source_buffer = source_array.GetBuffer<A>();
- auto& target_buffer = target_array->GetMutableBuffer<A>();
- target_buffer.data = source_buffer.data;
- }
-}
-} // namespace
-
void CloneArray(Model* model, const string& source_array_name,
const string& target_array_name) {
CHECK(!model->HasArray(target_array_name));
const Array& source_array = model->GetArray(source_array_name);
Array& target_array = model->GetOrCreateArray(target_array_name);
+ if (source_array.minmax) {
+ const auto& smm = source_array.GetMinMax();
+ auto& tmm = target_array.GetOrCreateMinMax();
+ tmm.min = smm.min;
+ tmm.max = smm.max;
+ }
+
+ if (source_array.quantization_params) {
+ const auto& sqp = source_array.GetQuantizationParams();
+ auto& tqp = target_array.GetOrCreateQuantizationParams();
+ tqp.zero_point = sqp.zero_point;
+ tqp.scale = sqp.scale;
+ }
+
+ target_array.data_type = source_array.data_type;
+ target_array.final_data_type = source_array.final_data_type;
+ target_array.copy_shape(source_array.shape());
+
switch (source_array.data_type) {
case ArrayDataType::kBool:
CopyArrayBuffer<ArrayDataType::kBool>(source_array, &target_array);
<< ArrayDataTypeName(source_array.data_type);
return;
}
-
- if (source_array.minmax) {
- const auto& smm = source_array.GetMinMax();
- auto& tmm = target_array.GetOrCreateMinMax();
- tmm.min = smm.min;
- tmm.max = smm.max;
- }
-
- if (source_array.quantization_params) {
- const auto& sqp = source_array.GetQuantizationParams();
- auto& tqp = target_array.GetOrCreateQuantizationParams();
- tqp.zero_point = sqp.zero_point;
- tqp.scale = sqp.scale;
- }
-
- target_array.data_type = source_array.data_type;
- target_array.final_data_type = source_array.final_data_type;
-
- target_array.copy_shape(source_array.shape());
}
void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
// Fixes input/output arrays that may have issues during export or inference.
void FixEdgeArrays(Model* model);
+// Copies the contents of an array into another.
+// Expects that the shape and data type match.
+template <ArrayDataType A>
+void CopyArrayBuffer(const Array& source_array, Array* target_array) {
+ int source_buffer_size = RequiredBufferSizeForShape(source_array.shape());
+ int target_buffer_size = RequiredBufferSizeForShape(target_array->shape());
+ CHECK_EQ(source_buffer_size, target_buffer_size)
+ << "Buffer sizes must match in element count";
+ CHECK(source_array.data_type == target_array->data_type)
+ << "Data types must match";
+ if (source_array.buffer) {
+ const auto& source_buffer = source_array.GetBuffer<A>();
+ auto& target_buffer = target_array->GetMutableBuffer<A>();
+ target_buffer.data = source_buffer.data;
+ }
+}
+
// Inserts a no-op reshape operator between the source array and the target
// array. This effectively just copies the data.
void InsertCopyOperator(Model* model, const string& source_array_name,