From dac1f124020234fe24e8893a981b15395d0c6de8 Mon Sep 17 00:00:00 2001
From: Benoit Steiner
Date: Wed, 23 May 2018 16:45:26 -0700
Subject: [PATCH] Simplify the remapper code and added support for non scalar
 mean, variance, scale and offset.

PiperOrigin-RevId: 197812268
---
 tensorflow/core/grappler/optimizers/BUILD          |  1 +
 tensorflow/core/grappler/optimizers/remapper.cc    | 87 ++++++++++++++++++----
 .../core/grappler/optimizers/remapper_test.cc      | 37 +++++++++
 .../graph_transforms/fold_old_batch_norms_test.cc  |  2 +-
 4 files changed, 112 insertions(+), 15 deletions(-)

diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 104a042..f686069 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -695,6 +695,7 @@ tf_cuda_cc_test(
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
+        "//tensorflow/core/grappler:devices",
         "//tensorflow/core/grappler:grappler_item",
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 2a62871..efd870b 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -28,10 +28,71 @@ namespace grappler {
 
 void AddBatchNormNodes(GraphDef* optimized_graph, const NodeDef& fused_node) {
   const string& x = fused_node.input(0);
-  const string& scale = fused_node.input(1);
-  const string& offset = fused_node.input(2);
-  const string& mean = fused_node.input(3);
-  const string& variance = fused_node.input(4);
+  string scale = fused_node.input(1);
+  string offset = fused_node.input(2);
+  string mean = fused_node.input(3);
+  string variance = fused_node.input(4);
+
+  if (fused_node.attr().at("data_format").s() == "NCHW") {
+    // Need to reshape the last 4 inputs
+    NodeDef* new_shape = optimized_graph->add_node();
+    new_shape->set_name(AddPrefixToNodeName("NCHWShape", fused_node.name()));
+    new_shape->set_op("Const");
+    new_shape->set_device(fused_node.device());
+    *new_shape->add_input() = AsControlDependency(scale);
+    (*new_shape->mutable_attr())["dtype"].set_type(DT_INT32);
+    Tensor t(DT_INT32, {4});
+    t.flat<int32>()(0) = 1;
+    t.flat<int32>()(1) = -1;
+    t.flat<int32>()(2) = 1;
+    t.flat<int32>()(3) = 1;
+    t.AsProtoTensorContent(
+        (*new_shape->mutable_attr())["value"].mutable_tensor());
+
+    NodeDef* reshaped_scale = optimized_graph->add_node();
+    reshaped_scale->set_name(
+        AddPrefixToNodeName("NCHWShapedScale", fused_node.name()));
+    reshaped_scale->set_op("Reshape");
+    reshaped_scale->set_device(fused_node.device());
+    *reshaped_scale->add_input() = scale;
+    *reshaped_scale->add_input() = new_shape->name();
+    (*reshaped_scale->mutable_attr())["T"] = fused_node.attr().at("T");
+    (*reshaped_scale->mutable_attr())["Tshape"].set_type(DT_INT32);
+    scale = reshaped_scale->name();
+
+    NodeDef* reshaped_offset = optimized_graph->add_node();
+    reshaped_offset->set_name(
+        AddPrefixToNodeName("NCHWShapedOffset", fused_node.name()));
+    reshaped_offset->set_op("Reshape");
+    reshaped_offset->set_device(fused_node.device());
+    *reshaped_offset->add_input() = offset;
+    *reshaped_offset->add_input() = new_shape->name();
+    (*reshaped_offset->mutable_attr())["T"] = fused_node.attr().at("T");
+    (*reshaped_offset->mutable_attr())["Tshape"].set_type(DT_INT32);
+    offset = reshaped_offset->name();
+
+    NodeDef* reshaped_mean = optimized_graph->add_node();
+    reshaped_mean->set_name(
+        AddPrefixToNodeName("NCHWShapedMean", fused_node.name()));
+    reshaped_mean->set_op("Reshape");
+    reshaped_mean->set_device(fused_node.device());
+    *reshaped_mean->add_input() = mean;
+    *reshaped_mean->add_input() = new_shape->name();
+    (*reshaped_mean->mutable_attr())["T"] = fused_node.attr().at("T");
+    (*reshaped_mean->mutable_attr())["Tshape"].set_type(DT_INT32);
+    mean = reshaped_mean->name();
+
+    NodeDef* reshaped_variance = optimized_graph->add_node();
+    reshaped_variance->set_name(
+        AddPrefixToNodeName("NCHWShapedVariance", fused_node.name()));
+    reshaped_variance->set_op("Reshape");
+    reshaped_variance->set_device(fused_node.device());
+    *reshaped_variance->add_input() = variance;
+    *reshaped_variance->add_input() = new_shape->name();
+    (*reshaped_variance->mutable_attr())["T"] = fused_node.attr().at("T");
+    (*reshaped_variance->mutable_attr())["Tshape"].set_type(DT_INT32);
+    variance = reshaped_variance->name();
+  }
 
   float epsilon = 0.0f;
   if (fused_node.attr().count("epsilon")) {
@@ -118,20 +179,16 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
       optimizable &= (node.attr().count("is_training") == 0 ||
                       !node.attr().at("is_training").b());
       if (optimizable) {
-        std::unordered_set<int> const_inputs;
-        for (const string& input : node.input()) {
-          int pos;
-          const string input_node = ParseNodeName(input, &pos);
-          if (properties.HasInputProperties(input_node)) {
-            const auto& props = properties.GetInputProperties(input_node);
-            if (props.size() > pos && props[pos].has_value()) {
-              const_inputs.insert(pos);
-            }
+        int const_inputs = 0;
+        const auto& props = properties.GetInputProperties(node.name());
+        for (const auto& prop : props) {
+          if (prop.has_value()) {
+            const_inputs += 1;
           }
         }
         // TODO(bsteiner): use the cost model to compare the cost of fused batch
         // norm against that of the optimized form.
-        optimizable = (const_inputs.size() >= 4);
+        optimizable = (const_inputs >= 4);
       }
       if (optimizable) {
         for (GraphView::Edge edge : graph.GetFanoutEdges(node, false)) {
@@ -143,6 +200,8 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
         }
       }
       if (optimizable) {
+        std::cout << "Optimizing fused batch norm node " << node.DebugString()
+                  << std::endl;
         AddBatchNormNodes(optimized_graph, node);
         continue;
       }
diff --git a/tensorflow/core/grappler/optimizers/remapper_test.cc b/tensorflow/core/grappler/optimizers/remapper_test.cc
index 291585c..4cbf0d8 100644
--- a/tensorflow/core/grappler/optimizers/remapper_test.cc
+++ b/tensorflow/core/grappler/optimizers/remapper_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/optimizers/remapper.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/devices.h"
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/utils/grappler_test.h"
 #include "tensorflow/core/platform/test.h"
@@ -54,5 +55,41 @@ TEST_F(RemapperTest, FusedBatchNorm) {
   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
 
+TEST_F(RemapperTest, FusedBatchNormNCHW) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output dflt =
+      ops::Const(s.WithOpName("dflt"), {3.14f, 2.7f, 1.0f, 2.0f, 3.0f, 100.0f},
+                 {1, 3, 1, 2});
+  Output x = ops::PlaceholderWithDefault(s.WithOpName("x"), dflt, {1, 3, 1, 2});
+  Output scale = ops::Const(s.WithOpName("scale"), {0.3f, 7.0f, 123.0f}, {3});
+  Output offset =
+      ops::Const(s.WithOpName("offset"), {0.123f, 2.1f, 0.55f}, {3});
+  Output mean = ops::Const(s.WithOpName("mean"), {7.3f, 8.3f, 3.1f}, {3});
+  Output variance =
+      ops::Const(s.WithOpName("variance"), {0.57f, 1.0f, 2.0f}, {3});
+  ops::FusedBatchNorm::Attrs attr;
+  attr = attr.IsTraining(false);
+  attr = attr.DataFormat("NCHW");
+  ops::FusedBatchNorm bn(s.WithOpName("batch_norm").WithDevice("/device:GPU:0"),
+                         x, scale, offset, mean, variance, attr);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  item.fetch = {"batch_norm"};
+
+  Remapper optimizer(RewriterConfig::ON);
+  GraphDef output;
+  TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
+
+  if (GetNumAvailableGPUs() > 0) {
+    // NCHW batch norm is only supported on GPU.
+    auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+    EXPECT_EQ(1, tensors_expected.size());
+    auto tensors = EvaluateNodes(output, item.fetch);
+    EXPECT_EQ(1, tensors.size());
+    test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+  }
+}
+
 }  // namespace grappler
 }  // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
index 7651a03..435f46c 100644
--- a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
@@ -191,7 +191,7 @@ class FoldOldBatchNormsTest : public ::testing::Test {
     std::vector<Tensor> fused_outputs;
     TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
 
-    test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
+    test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 2e-5);
 
     for (const NodeDef& node : fused_graph_def.node()) {
       EXPECT_NE("FusedBatchNorm", node.op());
-- 
2.7.4