Allowing trivial passthrough ops to be turned into reshapes when they otherwise canno...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 9 May 2018 22:44:13 +0000 (15:44 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 9 May 2018 22:47:09 +0000 (15:47 -0700)
PiperOrigin-RevId: 196041444

tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc

index 3e021b8..971e4ff 100644 (file)
@@ -95,10 +95,23 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
         "Cannot remove %s, neither its main input nor its output may be "
         "discarded",
         LogName(*passthru_op));
-    return false;
+    if (passthru_op->type != OperatorType::kTensorFlowReshape &&
+        model->GetArray(main_input_name).has_shape()) {
+      // We can't remove either array but we can remove the op. Converting it to
+      // a reshape gives us some hope of later on fixing that (either in the
+      // final runtime or as an additional fixup step).
+      //
+      // Note that we don't try to insert copies in place of reshapes as the
+      // copy itself is a trivial reshape and we'd go into an infinite loop!
+      transformation->AddMessageF("Replacing with a copy (reshape) instead");
+      InsertCopyOperator(model, main_input_name, output_name);
+    } else {
+      return false;
+    }
   }
 
   // Remove the pass-through node.
+  CHECK_EQ(passthru_it->get(), passthru_op);
   model->operators.erase(passthru_it);
 
   // Remove any array that is no longer used.