Do not convert layout for Select if condition input is of unknown shape.
authorYao Zhang <yaozhang@google.com>
Sat, 10 Feb 2018 09:45:11 +0000 (01:45 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 10 Feb 2018 09:48:49 +0000 (01:48 -0800)
PiperOrigin-RevId: 185242138

tensorflow/core/grappler/optimizers/layout_optimizer.cc
tensorflow/python/grappler/layout_optimizer_test.py

index e1a2d65..5a62b77 100644 (file)
@@ -1428,7 +1428,6 @@ class IdentityNProcessor : public AgnosticNodeProcessor {
   std::vector<int> GetInputPos() const override { return input_pos_; }
 
   std::set<int> GetOutputPos() const override {
-    std::vector<int> input_poses;
     std::set<int> output_pos{};
     for (const auto& input_pos : input_pos_) {
       output_pos.insert(input_pos);
@@ -1572,6 +1571,16 @@ class SelectProcessor : public AgnosticNodeProcessor {
       : AgnosticNodeProcessor(opt_cxt) {}
 
  protected:
+  bool ShouldProcess() const override {
+    auto input0 = node_map_->GetNode(node_->input(0));
+    int input0_port;
+    ParseNodeName(node_->input(0), &input0_port);
+    bool is_input0_scalar_vector_4d = IsPortDimsN(*input0, input0_port, 0) ||
+                                      IsPortDimsN(*input0, input0_port, 1) ||
+                                      IsPortDimsN(*input0, input0_port, 4);
+    return AgnosticNodeProcessor::ShouldProcess() && is_input0_scalar_vector_4d;
+  }
+
   std::vector<int> GetInputPos() const override {
     auto input0 = node_map_->GetNode(node_->input(0));
     int input0_port;
index 25b1cdc..30dcdf3 100644 (file)
@@ -771,6 +771,37 @@ class LayoutOptimizerTest(test.TestCase):
       self._assert_trans_nchw_to_nhwc('Select-0-0', nodes)
       self.assertAllClose(output_val_ref, output_val, atol=1e-3)
 
+  def testSelectOpConditionUnknownShape(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x = random_ops.truncated_normal([1, 784], seed=0)
+      conv = _two_layer_model(x)
+      add = math_ops.add(conv, conv)
+      condition = array_ops.placeholder(dtype='bool')
+      select = gen_math_ops._select(condition, conv, add)
+      output = array_ops.identity(select)
+
+      condition_val = np.zeros((1, 7, 7, 64))
+      with session.Session() as sess:
+        output_val_ref = sess.run(output, feed_dict={condition: condition_val})
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(
+            output, run_metadata=metadata, feed_dict={condition: condition_val})
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if _is_transpose(node.name):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      expected_num_transposes = 3
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
   def testSelectOpScalarCondition(self):
     if test.is_gpu_available(cuda_only=True):
       random_seed.set_random_seed(0)