Add node types for DFS traversal to catch more issues with deduping inputs to in...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 27 Mar 2018 22:47:23 +0000 (15:47 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 27 Mar 2018 22:50:26 +0000 (15:50 -0700)
PiperOrigin-RevId: 190687820

tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc

index 23e2185..5dd0b6f 100644 (file)
@@ -1089,7 +1089,8 @@ namespace {
 
 bool FeedsInPlaceOp(const SimpleGraphView& graph_view, const NodeDef& node) {
   const std::unordered_set<string> op_types_to_traverse = {
-      node.op(), "Identity", "IdentityN", "Reshape"};
+      node.op(),    "Identity", "IdentityN", "Reshape",
+      "ExpandDims", "Enter",    "Switch",    "Merge"};
   int node_idx = graph_view.index(node.name());
   std::set<int> node_fanout;
   graph_view.DepthFirstSearch(op_types_to_traverse, node_idx, &node_fanout);