Add missing update of node map in the Mul(x,x) => Square(x) rewrite. This is what...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 9 May 2018 22:58:28 +0000 (15:58 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 9 May 2018 23:02:11 +0000 (16:02 -0700)
PiperOrigin-RevId: 196043455

tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc

index adfae2e..f46c30c 100644 (file)
@@ -2233,6 +2233,9 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
         new_square_node->set_input(i - 1, new_square_node->input(i));
       }
       new_square_node->mutable_input()->RemoveLast();
+      for (const string& input : new_square_node->input()) {
+        node_map_->AddOutput(NodeName(input), new_square_node->name());
+      }
       return new_square_node->name();
     }
   }