ResolveConstantReshape transformation and fix for ResolveConstantTranspose.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 12 Apr 2018 21:31:08 +0000 (14:31 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 12 Apr 2018 21:33:53 +0000 (14:33 -0700)
PiperOrigin-RevId: 192670991

tensorflow/contrib/lite/toco/BUILD
tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc [new file with mode: 0644]
tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc
tensorflow/contrib/lite/toco/toco_tooling.cc
tensorflow/contrib/lite/toco/tooling_util.cc
tensorflow/contrib/lite/toco/tooling_util.h

index a05d719..4c8652d 100644 (file)
@@ -266,6 +266,7 @@ cc_library(
         "graph_transformations/resolve_constant_gather.cc",
         "graph_transformations/resolve_constant_random_uniform.cc",
         "graph_transformations/resolve_constant_range.cc",
+        "graph_transformations/resolve_constant_reshape.cc",
         "graph_transformations/resolve_constant_shape_or_rank.cc",
         "graph_transformations/resolve_constant_stack.cc",
         "graph_transformations/resolve_constant_strided_slice.cc",
index 80463ce..384bd85 100644 (file)
@@ -165,6 +165,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch)
 DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowTile)
 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFakeQuant)
 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantReshape)
 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTranspose)
 DECLARE_GRAPH_TRANSFORMATION(DropFakeQuant)
 DECLARE_GRAPH_TRANSFORMATION(UnfuseActivationFunctions)
index 61477d5..e28d8cf 100644 (file)
@@ -41,8 +41,8 @@ bool IsReshapeTrivial(const Model& model, const Operator& op,
         ShapesAgreeUpToExtending(input_array.shape(), output_array.shape())) {
       transformation->AddMessageF(
           "%s is trivial because its input and output shapes are equal up to "
-          "extending "
-          "by 1's, and we are told to aggressively discard such Reshape ops.",
+          "extending by 1's, and we are told to aggressively discard such "
+          "Reshape ops.",
           LogName(op));
       return true;
     }
@@ -80,6 +80,7 @@ bool RemoveTrivialReshape::Run(Model* model, std::size_t op_index) {
   }
 
   if (!IsReshapeTrivial(*model, *reshape_op, this)) {
+    AddMessageF("%s is not trivial", LogName(*reshape_op));
     return false;
   }
 
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
new file mode 100644 (file)
index 0000000..7e7ad38
--- /dev/null
@@ -0,0 +1,124 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+// Resolves a constant reshape operation by copying the buffer.
+bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
+  auto it = model->operators.begin() + op_index;
+  const auto* base_op = it->get();
+  if (base_op->type != OperatorType::kTensorFlowReshape) {
+    return false;
+  }
+  const auto* op = static_cast<const TensorFlowReshapeOperator*>(base_op);
+
+  CHECK_EQ(op->inputs.size(), 2);
+  CHECK_EQ(op->outputs.size(), 1);
+
+  // We require constant inputs.
+  if (!IsConstantParameterArray(*model, op->inputs[0]) ||
+      !IsConstantParameterArray(*model, op->inputs[1])) {
+    return false;
+  }
+
+  auto& output_array = model->GetArray(op->outputs[0]);
+  if (output_array.data_type == ArrayDataType::kNone) {
+    // Yield until the output type has been set by PropagateArrayDataTypes.
+    return false;
+  }
+  if (!output_array.has_shape()) {
+    // Yield until the output shape has been set by PropagateFixedShapes.
+    return false;
+  }
+
+  const Array& input_array = model->GetArray(op->inputs[0]);
+  if (!ShapesAgreeUpToExtending(input_array.shape(), output_array.shape())) {
+    AddMessageF("Constant reshape is non-trivial (%s -> %s)",
+                ShapeToString(input_array.shape()),
+                ShapeToString(output_array.shape()));
+    return false;
+  }
+
+  CHECK(!output_array.buffer);
+  switch (input_array.data_type) {
+    case ArrayDataType::kBool:
+      CopyArrayBuffer<ArrayDataType::kBool>(input_array, &output_array);
+      break;
+    case ArrayDataType::kFloat:
+      CopyArrayBuffer<ArrayDataType::kFloat>(input_array, &output_array);
+      break;
+    case ArrayDataType::kInt8:
+      CopyArrayBuffer<ArrayDataType::kInt8>(input_array, &output_array);
+      break;
+    case ArrayDataType::kUint8:
+      CopyArrayBuffer<ArrayDataType::kUint8>(input_array, &output_array);
+      break;
+    case ArrayDataType::kInt16:
+      CopyArrayBuffer<ArrayDataType::kInt16>(input_array, &output_array);
+      break;
+    case ArrayDataType::kUint16:
+      CopyArrayBuffer<ArrayDataType::kUint16>(input_array, &output_array);
+      break;
+    case ArrayDataType::kInt32:
+      CopyArrayBuffer<ArrayDataType::kInt32>(input_array, &output_array);
+      break;
+    case ArrayDataType::kUint32:
+      CopyArrayBuffer<ArrayDataType::kUint32>(input_array, &output_array);
+      break;
+    case ArrayDataType::kInt64:
+      CopyArrayBuffer<ArrayDataType::kInt64>(input_array, &output_array);
+      break;
+    case ArrayDataType::kUint64:
+      CopyArrayBuffer<ArrayDataType::kUint64>(input_array, &output_array);
+      break;
+    case ArrayDataType::kString:
+      CopyArrayBuffer<ArrayDataType::kString>(input_array, &output_array);
+      break;
+    default:
+      LOG(FATAL) << "Unsupported data type: "
+                 << ArrayDataTypeName(input_array.data_type);
+      return false;
+  }
+
+  AddMessageF("Resolving constant reshape of %s", LogName(*op));
+
+  if (input_array.minmax) {
+    output_array.GetOrCreateMinMax() = input_array.GetMinMax();
+  }
+  if (input_array.quantization_params) {
+    output_array.GetOrCreateQuantizationParams() =
+        input_array.GetQuantizationParams();
+  }
+
+  // Erase input arrays if no longer used.
+  for (const auto& input : op->inputs) {
+    if (IsDiscardableArray(*model, input) &&
+        CountOpsWithInput(*model, input) == 1) {
+      model->EraseArray(input);
+    }
+  }
+
+  // Erase the operator.
+  model->operators.erase(it);
+  return true;
+}
+
+}  // namespace toco
index 4f984bf..1fd2031 100644 (file)
@@ -131,6 +131,10 @@ bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) {
   if (input_array.minmax) {
     output_array.GetOrCreateMinMax() = input_array.GetMinMax();
   }
+  if (input_array.quantization_params) {
+    output_array.GetOrCreateQuantizationParams() =
+        input_array.GetQuantizationParams();
+  }
 
   if (op->perm.empty()) {
     // Yield until perm has been populated by ResolveTransposeAttributes.
@@ -164,6 +168,8 @@ bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) {
       break;
   }
 
+  AddMessageF("Resolving constant transpose of %s", LogName(*op));
+
   // Erase input arrays if no longer used.
   for (const auto& input : op->inputs) {
     if (IsDiscardableArray(*model, input) &&
index 1ab0a6f..5ba093a 100644 (file)
@@ -83,6 +83,7 @@ void MakeGeneralGraphTransformationsSet(
   transformations->Add(new ResolveConstantGather);
   transformations->Add(new ResolveConstantRandomUniform);
   transformations->Add(new ResolveConstantRange);
+  transformations->Add(new ResolveConstantReshape);
   transformations->Add(new ResolveConstantStack);
   transformations->Add(new ResolveConstantStridedSlice);
   transformations->Add(new ResolveConstantTranspose);
index bd2d5f7..224df99 100644 (file)
@@ -1084,23 +1084,30 @@ void InsertCopyOperator(Model* model, const string& source_array_name,
   model->operators.emplace_back(copy_op);
 }
 
-namespace {
-template <ArrayDataType A>
-void CopyArrayBuffer(const Array& source_array, Array* target_array) {
-  if (source_array.buffer) {
-    const auto& source_buffer = source_array.GetBuffer<A>();
-    auto& target_buffer = target_array->GetMutableBuffer<A>();
-    target_buffer.data = source_buffer.data;
-  }
-}
-}  // namespace
-
 void CloneArray(Model* model, const string& source_array_name,
                 const string& target_array_name) {
   CHECK(!model->HasArray(target_array_name));
   const Array& source_array = model->GetArray(source_array_name);
   Array& target_array = model->GetOrCreateArray(target_array_name);
 
+  if (source_array.minmax) {
+    const auto& smm = source_array.GetMinMax();
+    auto& tmm = target_array.GetOrCreateMinMax();
+    tmm.min = smm.min;
+    tmm.max = smm.max;
+  }
+
+  if (source_array.quantization_params) {
+    const auto& sqp = source_array.GetQuantizationParams();
+    auto& tqp = target_array.GetOrCreateQuantizationParams();
+    tqp.zero_point = sqp.zero_point;
+    tqp.scale = sqp.scale;
+  }
+
+  target_array.data_type = source_array.data_type;
+  target_array.final_data_type = source_array.final_data_type;
+  target_array.copy_shape(source_array.shape());
+
   switch (source_array.data_type) {
     case ArrayDataType::kBool:
       CopyArrayBuffer<ArrayDataType::kBool>(source_array, &target_array);
@@ -1140,25 +1147,6 @@ void CloneArray(Model* model, const string& source_array_name,
                  << ArrayDataTypeName(source_array.data_type);
       return;
   }
-
-  if (source_array.minmax) {
-    const auto& smm = source_array.GetMinMax();
-    auto& tmm = target_array.GetOrCreateMinMax();
-    tmm.min = smm.min;
-    tmm.max = smm.max;
-  }
-
-  if (source_array.quantization_params) {
-    const auto& sqp = source_array.GetQuantizationParams();
-    auto& tqp = target_array.GetOrCreateQuantizationParams();
-    tqp.zero_point = sqp.zero_point;
-    tqp.scale = sqp.scale;
-  }
-
-  target_array.data_type = source_array.data_type;
-  target_array.final_data_type = source_array.final_data_type;
-
-  target_array.copy_shape(source_array.shape());
 }
 
 void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
index dfd8117..ed0ecd4 100644 (file)
@@ -147,6 +147,23 @@ void FixNoOrphanedArray(Model* model);
 // Fixes input/output arrays that may have issues during export or inference.
 void FixEdgeArrays(Model* model);
 
+// Copies the contents of an array into another.
+// Expects that the shape and data type match.
+template <ArrayDataType A>
+void CopyArrayBuffer(const Array& source_array, Array* target_array) {
+  int source_buffer_size = RequiredBufferSizeForShape(source_array.shape());
+  int target_buffer_size = RequiredBufferSizeForShape(target_array->shape());
+  CHECK_EQ(source_buffer_size, target_buffer_size)
+      << "Buffer sizes must match in element count";
+  CHECK(source_array.data_type == target_array->data_type)
+      << "Data types must match";
+  if (source_array.buffer) {
+    const auto& source_buffer = source_array.GetBuffer<A>();
+    auto& target_buffer = target_array->GetMutableBuffer<A>();
+    target_buffer.data = source_buffer.data;
+  }
+}
+
 // Inserts a no-op reshape operator between the source array and the target
 // array. This effectively just copies the data.
 void InsertCopyOperator(Model* model, const string& source_array_name,