Support reduction with true keep_dims and squeeze along NHW dimensions.
authorYao Zhang <yaozhang@google.com>
Mon, 12 Feb 2018 19:54:07 +0000 (11:54 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 12 Feb 2018 19:58:11 +0000 (11:58 -0800)
PiperOrigin-RevId: 185411786

tensorflow/core/grappler/optimizers/layout_optimizer.cc
tensorflow/python/grappler/layout_optimizer_test.py

index 5a62b77..a606f97 100644 (file)
@@ -1717,13 +1717,28 @@ class SqueezeProcessor : public AgnosticNodeProcessor {
 
  protected:
   bool ShouldProcess() const override {
-    return !MustPreserve() && IsPortZeroDimsN(*node_, 2) && HasOutputs() &&
-           IsNodeAfterNCHWToNHWC() && IsInputConvertible() && IsAlongDimHW() &&
-           IsOnGPU();
+    bool is_dims_supported = (IsPortZeroDimsN(*node_, 2) && IsAlongHW()) ||
+                             (IsPortZeroDimsN(*node_, 1) && IsAlongNHW());
+    return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
+           IsInputConvertible() && is_dims_supported && IsOnGPU();
   }
 
   Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
 
+  Status CustomizedProcessing() override {
+    TF_RETURN_IF_ERROR(HasAttribute(*node_, "squeeze_dims"));
+    auto list = node_->mutable_attr()->at("squeeze_dims").mutable_list();
+    if (list->i_size() == 2) {
+      list->set_i(0, 2);
+      list->set_i(1, 3);
+    } else if (list->i_size() == 3) {
+      list->set_i(1, 2);
+      list->set_i(2, 3);
+    }
+    return Status::OK();
+  }
+
+ private:
   bool IsInputConvertible() const {
     int input_port;
     auto input = node_map_->GetNode(node_->input(0));
@@ -1736,33 +1751,31 @@ class SqueezeProcessor : public AgnosticNodeProcessor {
       if (shape.dim(1).size() == 1 && shape.dim(2).size() == 1) {
         return true;
       }
+      if (shape.dim(0).size() == 1 && shape.dim(1).size() == 1 &&
+          shape.dim(2).size() == 1) {
+        return true;
+      }
     }
     return false;
   }
 
-  bool IsAlongDimHW() const {
+  bool IsAlongAxis(const std::vector<int>& axis) const {
     if (node_->attr().find("squeeze_dims") != node_->attr().end()) {
       auto list = node_->attr().at("squeeze_dims").list();
       // If list is empty, Squeeze op will squeeze all dimensions of size 1.
       if (list.i_size() == 0) return true;
-      if (list.i_size() == 2) {
-        if (list.i(0) == 1 && list.i(1) == 2) {
-          return true;
+      if (list.i_size() == axis.size()) {
+        bool along_axis = true;
+        for (int i = 0; i < axis.size(); i++) {
+          along_axis = along_axis && (list.i(i) == axis[i]);
         }
+        if (along_axis) return true;
       }
     }
     return false;
   }
-
-  Status CustomizedProcessing() override {
-    TF_RETURN_IF_ERROR(HasAttribute(*node_, "squeeze_dims"));
-    auto list = node_->mutable_attr()->at("squeeze_dims").mutable_list();
-    if (list->i_size() == 2) {
-      list->set_i(0, 2);
-      list->set_i(1, 3);
-    }
-    return Status::OK();
-  }
+  bool IsAlongHW() const { return IsAlongAxis({1, 2}); }
+  bool IsAlongNHW() const { return IsAlongAxis({0, 1, 2}); }
 };
 
 class ReduceProcessor : public AgnosticNodeProcessor {
@@ -1789,12 +1802,18 @@ class ReduceProcessor : public AgnosticNodeProcessor {
     return Status::OK();
   }
 
-  Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
+  Status AddLayoutTransposeToOutputs() override {
+    if (KeepDims()) {
+      return AddTransformToOutputs("Transpose");
+    }
+    return Status::OK();
+  }
 
  private:
   bool IsReduceAxisSupported() const {
-    return IsAlongAllFourDims() || IsAlongHWC() ||
-           ((IsAlongNHW() || IsAlongHW() || IsAlongC()) && !KeepDims());
+    return KeepDims() || ((IsAlongAllFourDims() || IsAlongHWC() ||
+                           IsAlongNHW() || IsAlongHW() || IsAlongC()) &&
+                          !KeepDims());
   }
 
   bool IsAlongAxis(const std::vector<int>& axis) const {
index 30dcdf3..b04bbb0 100644 (file)
@@ -471,6 +471,66 @@ class LayoutOptimizerTest(test.TestCase):
       self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
       self.assertAllClose(output_val_ref, output_val, atol=1e-3)
 
+  def testSqueezeAlongHW(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x = random_ops.truncated_normal([1, 784], seed=0)
+      conv = _two_layer_model(x)
+      reduce_sum = math_ops.reduce_sum(conv, axis=[1, 2], keep_dims=True)
+      squeeze = array_ops.squeeze(reduce_sum, axis=[1, 2])
+      output = array_ops.identity(squeeze)
+
+      with session.Session() as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if _is_transpose(node.name):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      # Three transposes were initially added in the Expand phase of
+      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
+      expected_num_transposes = 1
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
+  def testSqueezeAlongNHW(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x = random_ops.truncated_normal([1, 784], seed=0)
+      conv = _two_layer_model(x)
+      reduce_sum = math_ops.reduce_sum(conv, axis=[0, 1, 2], keep_dims=True)
+      squeeze = array_ops.squeeze(reduce_sum, axis=[0, 1, 2])
+      output = array_ops.identity(squeeze)
+
+      with session.Session() as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if _is_transpose(node.name):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      # Three transposes were initially added in the Expand phase of
+      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
+      expected_num_transposes = 1
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
   def testReduceSumAlongHWC(self):
     if test.is_gpu_available(cuda_only=True):
       random_seed.set_random_seed(0)
@@ -558,6 +618,36 @@ class LayoutOptimizerTest(test.TestCase):
       self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
       self.assertAllClose(output_val_ref, output_val, atol=1e-3)
 
+  def testReduceSumAlongCKeepDims(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x = random_ops.truncated_normal([1, 784], seed=0)
+      conv = _two_layer_model(x)
+      reduce_sum = math_ops.reduce_sum(conv, axis=[3], keep_dims=True)
+      output = array_ops.identity(reduce_sum)
+
+      with session.Session() as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if _is_transpose(node.name):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      # Four transposes were initially added in the Expand phase of
+      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
+      expected_num_transposes = 2
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
+      self._assert_trans_nchw_to_nhwc('Sum-0-0', nodes)
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
   def testConcatWithControlDependency(self):
     if test.is_gpu_available(cuda_only=True):
       random_seed.set_random_seed(0)