Simplify the remapper code and add support for non-scalar mean, variance, scale...
Author:     Benoit Steiner <bsteiner@google.com>
AuthorDate: Wed, 23 May 2018 23:45:26 +0000 (16:45 -0700)
Commit:     TensorFlower Gardener <gardener@tensorflow.org>
CommitDate: Wed, 23 May 2018 23:48:01 +0000 (16:48 -0700)
PiperOrigin-RevId: 197812268

tensorflow/core/grappler/optimizers/BUILD
tensorflow/core/grappler/optimizers/remapper.cc
tensorflow/core/grappler/optimizers/remapper_test.cc
tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc

diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 104a042..f686069 100644
@@ -695,6 +695,7 @@ tf_cuda_cc_test(
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
+        "//tensorflow/core/grappler:devices",
         "//tensorflow/core/grappler:grappler_item",
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 2a62871..efd870b 100644
@@ -28,10 +28,71 @@ namespace grappler {
 
 void AddBatchNormNodes(GraphDef* optimized_graph, const NodeDef& fused_node) {
   const string& x = fused_node.input(0);
-  const string& scale = fused_node.input(1);
-  const string& offset = fused_node.input(2);
-  const string& mean = fused_node.input(3);
-  const string& variance = fused_node.input(4);
+  string scale = fused_node.input(1);
+  string offset = fused_node.input(2);
+  string mean = fused_node.input(3);
+  string variance = fused_node.input(4);
+
+  if (fused_node.attr().at("data_format").s() == "NCHW") {
+    // Reshape the four rank-1 [C] inputs to [1, C, 1, 1] for NCHW broadcasting.
+    NodeDef* new_shape = optimized_graph->add_node();
+    new_shape->set_name(AddPrefixToNodeName("NCHWShape", fused_node.name()));
+    new_shape->set_op("Const");
+    new_shape->set_device(fused_node.device());
+    *new_shape->add_input() = AsControlDependency(scale);
+    (*new_shape->mutable_attr())["dtype"].set_type(DT_INT32);
+    Tensor t(DT_INT32, {4});
+    t.flat<int32>()(0) = 1;
+    t.flat<int32>()(1) = -1;
+    t.flat<int32>()(2) = 1;
+    t.flat<int32>()(3) = 1;
+    t.AsProtoTensorContent(
+        (*new_shape->mutable_attr())["value"].mutable_tensor());
+
+    NodeDef* reshaped_scale = optimized_graph->add_node();
+    reshaped_scale->set_name(
+        AddPrefixToNodeName("NCHWShapedScale", fused_node.name()));
+    reshaped_scale->set_op("Reshape");
+    reshaped_scale->set_device(fused_node.device());
+    *reshaped_scale->add_input() = scale;
+    *reshaped_scale->add_input() = new_shape->name();
+    (*reshaped_scale->mutable_attr())["T"] = fused_node.attr().at("T");
+    (*reshaped_scale->mutable_attr())["Tshape"].set_type(DT_INT32);
+    scale = reshaped_scale->name();
+
+    NodeDef* reshaped_offset = optimized_graph->add_node();
+    reshaped_offset->set_name(
+        AddPrefixToNodeName("NCHWShapedOffset", fused_node.name()));
+    reshaped_offset->set_op("Reshape");
+    reshaped_offset->set_device(fused_node.device());
+    *reshaped_offset->add_input() = offset;
+    *reshaped_offset->add_input() = new_shape->name();
+    (*reshaped_offset->mutable_attr())["T"] = fused_node.attr().at("T");
+    (*reshaped_offset->mutable_attr())["Tshape"].set_type(DT_INT32);
+    offset = reshaped_offset->name();
+
+    NodeDef* reshaped_mean = optimized_graph->add_node();
+    reshaped_mean->set_name(
+        AddPrefixToNodeName("NCHWShapedMean", fused_node.name()));
+    reshaped_mean->set_op("Reshape");
+    reshaped_mean->set_device(fused_node.device());
+    *reshaped_mean->add_input() = mean;
+    *reshaped_mean->add_input() = new_shape->name();
+    (*reshaped_mean->mutable_attr())["T"] = fused_node.attr().at("T");
+    (*reshaped_mean->mutable_attr())["Tshape"].set_type(DT_INT32);
+    mean = reshaped_mean->name();
+
+    NodeDef* reshaped_variance = optimized_graph->add_node();
+    reshaped_variance->set_name(
+        AddPrefixToNodeName("NCHWShapedVariance", fused_node.name()));
+    reshaped_variance->set_op("Reshape");
+    reshaped_variance->set_device(fused_node.device());
+    *reshaped_variance->add_input() = variance;
+    *reshaped_variance->add_input() = new_shape->name();
+    (*reshaped_variance->mutable_attr())["T"] = fused_node.attr().at("T");
+    (*reshaped_variance->mutable_attr())["Tshape"].set_type(DT_INT32);
+    variance = reshaped_variance->name();
+  }
 
   float epsilon = 0.0f;
   if (fused_node.attr().count("epsilon")) {
@@ -118,20 +179,16 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
       optimizable &= (node.attr().count("is_training") == 0 ||
                       !node.attr().at("is_training").b());
       if (optimizable) {
-        std::unordered_set<int> const_inputs;
-        for (const string& input : node.input()) {
-          int pos;
-          const string input_node = ParseNodeName(input, &pos);
-          if (properties.HasInputProperties(input_node)) {
-            const auto& props = properties.GetInputProperties(input_node);
-            if (props.size() > pos && props[pos].has_value()) {
-              const_inputs.insert(pos);
-            }
+        int const_inputs = 0;
+        const auto& props = properties.GetInputProperties(node.name());
+        for (const auto& prop : props) {
+          if (prop.has_value()) {
+            const_inputs += 1;
           }
         }
         // TODO(bsteiner): use the cost model to compare the cost of fused batch
         // norm against that of the optimized form.
-        optimizable = (const_inputs.size() >= 4);
+        optimizable = (const_inputs >= 4);
       }
       if (optimizable) {
         for (GraphView::Edge edge : graph.GetFanoutEdges(node, false)) {
@@ -143,6 +200,8 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
         }
       }
       if (optimizable) {
+        VLOG(1) << "Optimizing fused batch norm node "
+                << node.DebugString();
         AddBatchNormNodes(optimized_graph, node);
         continue;
       }
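
For NHWC graphs the rank-1 [C] scale, offset, mean and variance tensors broadcast correctly against the trailing channel dimension of the elementwise primitives that AddBatchNormNodes emits; for NCHW the channel axis comes second, so each vector is first reshaped to [1, C, 1, 1] via the shared [1, -1, 1, 1] shape Const. The four Reshape blocks in the hunk differ only in names, so they could be collapsed into a helper. A hypothetical refactoring sketch, not the committed code, built from the same calls the diff already uses:

string ReshapeToNCHWVector(GraphDef* optimized_graph, const NodeDef& fused_node,
                           const string& input, const string& shape_node,
                           const string& prefix) {
  // Mirrors the repeated blocks above: Reshape `input` with the shared
  // [1, -1, 1, 1] shape Const so it broadcasts along the channel axis.
  NodeDef* reshaped = optimized_graph->add_node();
  reshaped->set_name(AddPrefixToNodeName(prefix, fused_node.name()));
  reshaped->set_op("Reshape");
  reshaped->set_device(fused_node.device());
  *reshaped->add_input() = input;
  *reshaped->add_input() = shape_node;
  (*reshaped->mutable_attr())["T"] = fused_node.attr().at("T");
  (*reshaped->mutable_attr())["Tshape"].set_type(DT_INT32);
  return reshaped->name();
}

// Usage, replacing the four blocks:
//   scale    = ReshapeToNCHWVector(optimized_graph, fused_node, scale,
//                                  new_shape->name(), "NCHWShapedScale");
//   ... and likewise for offset, mean and variance.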
diff --git a/tensorflow/core/grappler/optimizers/remapper_test.cc b/tensorflow/core/grappler/optimizers/remapper_test.cc
index 291585c..4cbf0d8 100644
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/optimizers/remapper.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/devices.h"
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/utils/grappler_test.h"
 #include "tensorflow/core/platform/test.h"
@@ -54,5 +55,41 @@ TEST_F(RemapperTest, FusedBatchNorm) {
   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
 
+TEST_F(RemapperTest, FusedBatchNormNCHW) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output dflt =
+      ops::Const(s.WithOpName("dflt"), {3.14f, 2.7f, 1.0f, 2.0f, 3.0f, 100.0f},
+                 {1, 3, 1, 2});
+  Output x = ops::PlaceholderWithDefault(s.WithOpName("x"), dflt, {1, 3, 1, 2});
+  Output scale = ops::Const(s.WithOpName("scale"), {0.3f, 7.0f, 123.0f}, {3});
+  Output offset =
+      ops::Const(s.WithOpName("offset"), {0.123f, 2.1f, 0.55f}, {3});
+  Output mean = ops::Const(s.WithOpName("mean"), {7.3f, 8.3f, 3.1f}, {3});
+  Output variance =
+      ops::Const(s.WithOpName("variance"), {0.57f, 1.0f, 2.0f}, {3});
+  ops::FusedBatchNorm::Attrs attr;
+  attr = attr.IsTraining(false);
+  attr = attr.DataFormat("NCHW");
+  ops::FusedBatchNorm bn(s.WithOpName("batch_norm").WithDevice("/device:GPU:0"),
+                         x, scale, offset, mean, variance, attr);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  item.fetch = {"batch_norm"};
+
+  Remapper optimizer(RewriterConfig::ON);
+  GraphDef output;
+  TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
+
+  if (GetNumAvailableGPUs() > 0) {
+    // NCHW batch norm is only supported on GPU.
+    auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+    EXPECT_EQ(1, tensors_expected.size());
+    auto tensors = EvaluateNodes(output, item.fetch);
+    EXPECT_EQ(1, tensors.size());
+    test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+  }
+}
+
 }  // namespace grappler
 }  // namespace tensorflow
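
Note that Remapper::Optimize runs unconditionally in this test; only the numeric comparison is gated on GetNumAvailableGPUs(). A hypothetical extra assertion, mirroring the structural check in fold_old_batch_norms_test.cc below, could verify the rewrite even on CPU-only machines:

// Hypothetical addition to FusedBatchNormNCHW, not part of the commit:
// after Optimize() the rewritten graph should contain no FusedBatchNorm op.
for (const NodeDef& node : output.node()) {
  EXPECT_NE("FusedBatchNorm", node.op());
}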
diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
index 7651a03..435f46c 100644
@@ -191,7 +191,7 @@ class FoldOldBatchNormsTest : public ::testing::Test {
     std::vector<Tensor> fused_outputs;
     TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
 
-    test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
+    test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 2e-5);
 
     for (const NodeDef& node : fused_graph_def.node()) {
       EXPECT_NE("FusedBatchNorm", node.op());