Don't bypass reshape nodes that anchor control dependencies
authorBenoit Steiner <bsteiner@google.com>
Mon, 2 Apr 2018 20:53:52 +0000 (13:53 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 2 Apr 2018 20:56:28 +0000 (13:56 -0700)
PiperOrigin-RevId: 191342646

tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
tensorflow/core/grappler/utils.cc
tensorflow/core/grappler/utils.h

index 882e4d9..6e27259 100644 (file)
@@ -1344,19 +1344,18 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
     int output_pos = 0;
     string input_node_name = ParseNodeName(node->input(0), &output_pos);
     const NodeDef* input = node_map_->GetNode(input_node_name);
-    if (input->op() == "Reshape") {
+    if (input->op() == "Reshape" && !HasControlInputs(*input)) {
       reshape->set_input(0, input->input(0));
       node_map_->UpdateInput(reshape->name(), input->name(), input->input(0));
       nodes_to_simplify->PushBack(reshape);
       return reshape->name();
     }
 
-    // If the reshape is a no-op, forward its input to its consumers. This is
-    // considered aggressive, because users may state that the placeholder
-    // outputs tensors of shape [M, N] while feeding it with tensors of shape
-    // [M*N] (or worse). The reshape nodes are then necessary to update the
-    // tensor metadata to the required shape.
-    if (ReshapeIsIdentity(*reshape, *input, output_pos, *graph_properties_)) {
+    // If the reshape is a no-op, forward its input to its consumers, unless it
+    // anchors a control dependency since we want to make sure that control
+    // dependency is triggered.
+    if (ReshapeIsIdentity(*reshape, *input, output_pos, *graph_properties_) &&
+        !HasControlInputs(*reshape)) {
       return reshape->input(0);
     }
   }
index 86a6d50..5893f28 100644 (file)
@@ -255,6 +255,14 @@ int NumOutputs(const NodeDef& node, GraphDef* graph) {
   return num_outputs;
 }
 
+bool HasControlInputs(const NodeDef& node) {
+  int num_inputs = node.input_size();
+  if (num_inputs > 0 && IsControlInput(node.input(num_inputs - 1))) {
+    return true;
+  }
+  return false;
+}
+
 int NumNonControlInputs(const NodeDef& node) {
   int num_inputs = node.input_size();
   for (const string& input : node.input()) {
index 7aa3193..11555d7 100644 (file)
@@ -138,6 +138,9 @@ string AsControlDependency(const string& node);
 // some of the outputs may be unconnected.
 int NumOutputs(const NodeDef& node, GraphDef* graph);
 
+// Returns true iff the node has at least one control input.
+bool HasControlInputs(const NodeDef& node);
+
 // Number of connected non-control inputs.
 int NumNonControlInputs(const NodeDef& node);