Support scalar and vector condition for select.
author: Yao Zhang <yaozhang@google.com>
Thu, 4 Jan 2018 17:45:53 +0000 (09:45 -0800)
committer: TensorFlower Gardener <gardener@tensorflow.org>
Thu, 4 Jan 2018 17:50:08 +0000 (09:50 -0800)
PiperOrigin-RevId: 180809175

tensorflow/core/grappler/optimizers/layout_optimizer.cc
tensorflow/python/grappler/layout_optimizer_test.py

index 37ab46ffb1f849cc3e5265aab0f5a0f57a9b212a..37610d2857a58c632c602ef208e65354243fec10 100644 (file)
@@ -1566,6 +1566,26 @@ class TernaryOpProcessor : public AgnosticNodeProcessor {
   std::vector<int> GetInputPos() const override { return {0, 1, 2}; }
 };
 
+class SelectProcessor : public AgnosticNodeProcessor {
+ public:
+  explicit SelectProcessor(const OptimizeContext& opt_cxt)
+      : AgnosticNodeProcessor(opt_cxt) {}
+
+ protected:
+  std::vector<int> GetInputPos() const override {
+    auto input0 = node_map_->GetNode(node_->input(0));
+    int input0_port;
+    ParseNodeName(node_->input(0), &input0_port);
+    // Input 0 could be a scalar, a vector with size matching the first
+    // dimension of input 1 and 2, or must have the same shape as input 1 and 2.
+    if (IsPortDimsFour(*input0, input0_port)) {
+      return {0, 1, 2};
+    } else {
+      return {1, 2};
+    }
+  }
+};
+
 class UnaryGradProcessor : public AgnosticNodeProcessor {
  public:
   explicit UnaryGradProcessor(const OptimizeContext& opt_cxt)
@@ -1874,7 +1894,7 @@ class DataLayoutOptimizer : GraphProcessor {
           std::unique_ptr<NodeProcessor> node_processor;
           if (IsAddN(*node)) {
             node_processor.reset(new AddNProcessor(opt_cxt));
-          } else if (IsBetainc(*node) || IsSelect(*node)) {
+          } else if (IsBetainc(*node)) {
             node_processor.reset(new TernaryOpProcessor(opt_cxt));
           } else if (IsBinaryOp(*node)) {
             node_processor.reset(new BinaryOpProcessor(opt_cxt));
@@ -1895,6 +1915,8 @@ class DataLayoutOptimizer : GraphProcessor {
             node_processor.reset(new ReduceProcessor(opt_cxt));
           } else if (IsReverseV2(*node)) {
             node_processor.reset(new ReverseProcessor(opt_cxt));
+          } else if (IsSelect(*node)) {
+            node_processor.reset(new SelectProcessor(opt_cxt));
           } else if (IsSlice(*node)) {
             node_processor.reset(new SliceProcessor(opt_cxt));
           } else if (IsStridedSlice(*node)) {
index 68d7282cedd426e36e8abeb0437bd1d9a4131bcd..487f1b0f7a36c54b6a29e43eae864866519c36b6 100644 (file)
@@ -562,7 +562,7 @@ class LayoutOptimizerTest(test.TestCase):
       self.assertIn('LayoutOptimizerDimMapNHWCToNCHW_ReverseV2_1', nodes)
       self.assertAllClose(output_val_ref, output_val, atol=1e-3)
 
-  def testTernaryOp(self):
+  def testSelectOp(self):
     if test.is_gpu_available(cuda_only=True):
       random_seed.set_random_seed(0)
       x = random_ops.truncated_normal([1, 784], seed=0)
@@ -593,6 +593,36 @@ class LayoutOptimizerTest(test.TestCase):
       self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-Select-0-0', nodes)
       self.assertAllClose(output_val_ref, output_val, atol=1e-3)
 
+  def testSelectOpScalarCondition(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x = random_ops.truncated_normal([1, 784], seed=0)
+      conv = _two_layer_model(x)
+      add = math_ops.add(conv, conv)
+      condition = constant_op.constant(True)
+      select = gen_math_ops._select(condition, conv, add)
+      output = array_ops.identity(select)
+
+      with session.Session() as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if node.name.startswith('LayoutOptimizerTranspose'):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      expected_num_transposes = 2
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
+      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-Select-0-0', nodes)
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
   def testPadWithNonConstPaddings(self):
     if test.is_gpu_available(cuda_only=True):
       random_seed.set_random_seed(0)