Fix a bug in updating NodeMap, where the node name without port number should
have been used.

author     Yao Zhang <yaozhang@google.com>
           Thu, 11 Jan 2018 00:13:09 +0000 (16:13 -0800)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Thu, 11 Jan 2018 00:16:18 +0000 (16:16 -0800)

PiperOrigin-RevId: 181532901
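Context for the change, as a rough sketch rather than the actual grappler helper: a node's input string may carry a port suffix such as "Split:1" (and control inputs a leading '^'), while NodeMap is keyed by the bare node name. The fix below wraps the input string in NodeName() before calling UpdateOutput so the lookup uses the bare name. The standalone C++ snippet here only illustrates that string transformation; the function name StripPortAndControl is hypothetical and is not the TensorFlow implementation.

    #include <iostream>
    #include <string>

    // Illustrative only: drop an optional leading control marker '^' and a
    // trailing ':port' suffix so "Split:1" and "^Split" both map to "Split".
    std::string StripPortAndControl(const std::string& input) {
      std::string name = input;
      if (!name.empty() && name[0] == '^') name = name.substr(1);
      const size_t colon = name.rfind(':');
      if (colon != std::string::npos) name = name.substr(0, colon);
      return name;
    }

    int main() {
      // "Split:1" refers to output port 1 of node "Split"; a map keyed by
      // node name must be queried with the bare name "Split".
      std::cout << StripPortAndControl("Split:1") << "\n";  // Split
      std::cout << StripPortAndControl("^Split") << "\n";   // Split
      return 0;
    }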

tensorflow/core/grappler/optimizers/layout_optimizer.cc
tensorflow/python/grappler/layout_optimizer_test.py

diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index 870b5289b572e7d4b28be1e87f70e28dddbcc5cd..ea7b05d3810f7a4b9f6388e040df930526f6e47e 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -625,7 +625,8 @@ class NodeProcessor : public GraphProcessor {
           node_name, node_->input(pos), const_name, dtype,
           input_node->attr().at("_output_shapes").list().shape(output_pos),
           true);
-      node_map_->UpdateOutput(node_->input(pos), node_->name(), node_name);
+      node_map_->UpdateOutput(NodeName(node_->input(pos)), node_->name(),
+                              node_name);
       node_map_->AddOutput(node_name, node_->name());
       *node_->mutable_input(pos) = node_name;
     }
@@ -917,7 +918,7 @@ class NodeProcessor : public GraphProcessor {
     auto added_node =
         AddNodeDataFormatOp(name, node_->input(input_pos), op, dtype, true);
     *node_->mutable_input(input_pos) = added_node->name();
-    node_map_->UpdateOutput(added_node->input(0), node_->name(),
+    node_map_->UpdateOutput(NodeName(added_node->input(0)), node_->name(),
                             added_node->name());
     node_map_->AddOutput(added_node->name(), node_->name());
   }
@@ -1328,8 +1329,8 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
       AddNodeReshape(reshape_node_name, node_->input(vector_index),
                      shape_const_node_name, node_->attr().at("T").type());
       node_map_->AddOutput(shape_const_node_name, reshape_node_name);
-      node_map_->UpdateOutput(node_->input(vector_index), node_->name(),
-                              reshape_node_name);
+      node_map_->UpdateOutput(NodeName(node_->input(vector_index)),
+                              node_->name(), reshape_node_name);
       node_map_->AddOutput(reshape_node_name, node_->name());
       *node_->mutable_input(vector_index) = reshape_node_name;
     }
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 4fdd779ddd49f2c6e50e6f9964cd42e4447a928d..25c5ef6b68452c0b8f8dc67a15187db1df5e3934 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -283,7 +283,12 @@ class LayoutOptimizerTest(test.TestCase):
       conv = _two_layer_model(x)
       dim = array_ops.placeholder(dtype='int32')
       split = array_ops.split(conv, 2, axis=dim)
-      output = math_ops.reduce_sum(split[0])
+      scale = constant_op.constant(0.1, shape=[32])
+      offset = constant_op.constant(0.3, shape=[32])
+      bn0 = nn.fused_batch_norm(split[0], scale, offset)
+      bn1 = nn.fused_batch_norm(split[1], scale, offset)
+      add = bn0[0] + bn1[0]
+      output = array_ops.identity(add)
 
       with session.Session() as sess:
         output_val_ref = sess.run(output, feed_dict={dim: 3})
@@ -299,12 +304,10 @@ class LayoutOptimizerTest(test.TestCase):
           num_transposes += 1
         nodes.append(node.name)
 
-      # Four transposes were initially added in the Expand phase of
-      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
       expected_num_transposes = 2
       self.assertEqual(expected_num_transposes, num_transposes)
       self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
-      self._assert_trans_nchw_to_nhwc('split-0-0', nodes)
+      self._assert_trans_nchw_to_nhwc('add_2-0-0', nodes)
       self._assert_map_nhwc_to_nchw('split-0', nodes)
       self.assertAllClose(output_val_ref, output_val, atol=1e-3)