Fixing constant output arrays by inserting synthetic reshapes.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 16 Mar 2018 18:29:00 +0000 (11:29 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 16 Mar 2018 18:32:50 +0000 (11:32 -0700)
PiperOrigin-RevId: 189368237

tensorflow/contrib/lite/toco/toco_tooling.cc
tensorflow/contrib/lite/toco/tooling_util.cc
tensorflow/contrib/lite/toco/tooling_util.h

index 024335b..ca66110 100644 (file)
@@ -289,6 +289,10 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
     EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(model);
   }
 
+  // Fix any issues with IO edges. This must happen after any transform that
+  // may modify the structure of the edges.
+  FixEdgeArrays(model);
+
   LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model);
 
   if (output_format != GRAPHVIZ_DOT && output_format != TFLITE) {
index e70291a..2362206 100644 (file)
@@ -1047,6 +1047,117 @@ void CheckModelCounts(const Model& model) {
   }
 }
 
+void FixEdgeArrays(Model* model) {
+  for (const string& output_array_name : model->flags.output_arrays()) {
+    if (!GetOpWithOutput(*model, output_array_name)) {
+      // Output has no operator producing it. Change that by inserting a copy.
+      LOG(WARNING) << "Fixing constant output array " << output_array_name
+                   << " by inserting a copy. This is not optimal.";
+      string intermediate_array_name =
+          AvailableArrayName(*model, output_array_name + "_copy");
+      CloneArray(model, output_array_name, intermediate_array_name);
+      InsertCopyOperator(model, intermediate_array_name, output_array_name);
+    }
+  }
+}
+
+void InsertCopyOperator(Model* model, const string& source_array_name,
+                        const string& target_array_name) {
+  // Drop constant data from the target array as the copy will be done at
+  // runtime.
+  Array& target_array = model->GetOrCreateArray(target_array_name);
+  target_array.buffer.reset();
+
+  // Reshape to the same size. This should be a no-op.
+  const Array& source_array = model->GetArray(source_array_name);
+  std::vector<int> shape = source_array.shape().dims();
+
+  // Insert copy operator.
+  auto* copy_op = new TensorFlowReshapeOperator;
+  copy_op->inputs = {
+      source_array_name,
+      CreateInt32Array(model, target_array_name + "_copy_shape", shape)};
+  copy_op->outputs = {target_array_name};
+  model->operators.emplace_back(copy_op);
+}
+
+namespace {
+template <ArrayDataType A>
+void CopyArrayBuffer(const Array& source_array, Array* target_array) {
+  if (source_array.buffer) {
+    const auto& source_buffer = source_array.GetBuffer<A>();
+    auto& target_buffer = target_array->GetMutableBuffer<A>();
+    target_buffer.data = source_buffer.data;
+  }
+}
+}  // namespace
+
+void CloneArray(Model* model, const string& source_array_name,
+                const string& target_array_name) {
+  CHECK(!model->HasArray(target_array_name));
+  const Array& source_array = model->GetArray(source_array_name);
+  Array& target_array = model->GetOrCreateArray(target_array_name);
+
+  switch (source_array.data_type) {
+    case ArrayDataType::kBool:
+      CopyArrayBuffer<ArrayDataType::kBool>(source_array, &target_array);
+      break;
+    case ArrayDataType::kFloat:
+      CopyArrayBuffer<ArrayDataType::kFloat>(source_array, &target_array);
+      break;
+    case ArrayDataType::kInt8:
+      CopyArrayBuffer<ArrayDataType::kInt8>(source_array, &target_array);
+      break;
+    case ArrayDataType::kUint8:
+      CopyArrayBuffer<ArrayDataType::kUint8>(source_array, &target_array);
+      break;
+    case ArrayDataType::kInt16:
+      CopyArrayBuffer<ArrayDataType::kInt16>(source_array, &target_array);
+      break;
+    case ArrayDataType::kUint16:
+      CopyArrayBuffer<ArrayDataType::kUint16>(source_array, &target_array);
+      break;
+    case ArrayDataType::kInt32:
+      CopyArrayBuffer<ArrayDataType::kInt32>(source_array, &target_array);
+      break;
+    case ArrayDataType::kUint32:
+      CopyArrayBuffer<ArrayDataType::kUint32>(source_array, &target_array);
+      break;
+    case ArrayDataType::kInt64:
+      CopyArrayBuffer<ArrayDataType::kInt64>(source_array, &target_array);
+      break;
+    case ArrayDataType::kUint64:
+      CopyArrayBuffer<ArrayDataType::kUint64>(source_array, &target_array);
+      break;
+    case ArrayDataType::kString:
+      CopyArrayBuffer<ArrayDataType::kString>(source_array, &target_array);
+      break;
+    default:
+      LOG(FATAL) << "Unsupported data type: "
+                 << ArrayDataTypeName(source_array.data_type);
+      return;
+  }
+
+  if (source_array.minmax) {
+    const auto& smm = source_array.GetMinMax();
+    auto& tmm = target_array.GetOrCreateMinMax();
+    tmm.min = smm.min;
+    tmm.max = smm.max;
+  }
+
+  if (source_array.quantization_params) {
+    const auto& sqp = source_array.GetQuantizationParams();
+    auto& tqp = target_array.GetOrCreateQuantizationParams();
+    tqp.zero_point = sqp.zero_point;
+    tqp.scale = sqp.scale;
+  }
+
+  target_array.data_type = source_array.data_type;
+  target_array.final_data_type = source_array.final_data_type;
+
+  target_array.copy_shape(source_array.shape());
+}
+
 void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
                    std::vector<int>* out_dims) {
   CHECK(out_dims->empty());
index 05360e3..d3b7224 100644 (file)
@@ -144,6 +144,18 @@ void FixOperatorOrdering(Model* model);
 void FixNoMissingArray(Model* model);
 void FixNoOrphanedArray(Model* model);
 
+// Fixes input/output arrays that may have issues during export or inference.
+void FixEdgeArrays(Model* model);
+
+// Inserts a no-op reshape operator between the source array and the target
+// array. This effectively just copies the data.
+void InsertCopyOperator(Model* model, const string& source_array_name,
+                        const string& target_array_name);
+
+// Clones an array with all data and parameters.
+void CloneArray(Model* model, const string& source_array_name,
+                const string& target_array_name);
+
 void ResolveModelFlags(const ModelFlags& model_flags, Model* model);
 
 template <ArrayDataType A>