From 72b1a058613c26938a57670b3f32e29ba0e58d23 Mon Sep 17 00:00:00 2001 From: Yao Zhang Date: Thu, 8 Feb 2018 23:17:54 -0800 Subject: [PATCH] Only convert format if input is of layout-agnostic type. PiperOrigin-RevId: 185103227 --- .../core/grappler/optimizers/layout_optimizer.cc | 43 ++++++++++++++-------- .../python/grappler/layout_optimizer_test.py | 31 ++++++++++++++++ 2 files changed, 58 insertions(+), 16 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index 433b356..e1a2d65 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -1400,39 +1400,44 @@ class HistogramSummaryProcessor : public AgnosticNodeProcessor { class IdentityNProcessor : public AgnosticNodeProcessor { public: explicit IdentityNProcessor(const OptimizeContext& opt_cxt) - : AgnosticNodeProcessor(opt_cxt) {} - - protected: - bool ShouldProcess() const override { - return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() && - IsOnGPU(); - } - - std::vector GetInputPos() const override { - std::vector input_pos; + : AgnosticNodeProcessor(opt_cxt) { + std::set ops_format_agnostic = GetOpsFormatAgnostic(); for (int i = 0; i < node_->input_size(); i++) { auto input = node_map_->GetNode(node_->input(i)); int port; ParseNodeName(node_->input(i), &port); // Skip control input. if (port != -1) { + bool is_agnostic = + ops_format_agnostic.find(input->op()) != ops_format_agnostic.end(); if (IsPortDimsFour(*input, port) && - (IsNodeAfterNCHWToNHWC(*input) || + ((IsNodeAfterNCHWToNHWC(*input) && is_agnostic) || IsTransposeNCHWToNHWC(input->name()))) { - input_pos.push_back(i); + input_pos_.push_back(i); } } } - return input_pos; } + protected: + bool ShouldProcess() const override { + return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() && + IsOnGPU(); + } + + std::vector GetInputPos() const override { return input_pos_; } + std::set GetOutputPos() const override { + std::vector input_poses; std::set output_pos{}; - for (const auto& input_pos : GetInputPos()) { + for (const auto& input_pos : input_pos_) { output_pos.insert(input_pos); } return output_pos; } + + private: + std::vector input_pos_; }; class ShapeProcessor : public IdentityNProcessor { @@ -1471,10 +1476,16 @@ class MergeProcessor : public AgnosticNodeProcessor { private: bool IsEveryInputAfterNCHWToNHWC() const { + std::set ops_format_agnostic = GetOpsFormatAgnostic(); for (const auto& input : node_->input()) { auto input_node = node_map_->GetNode(input); - if (IsNodeAfterNCHWToNHWC(*input_node) || - IsTransposeNCHWToNHWC(input_node->name())) { + int port; + ParseNodeName(input, &port); + bool is_agnostic = ops_format_agnostic.find(input_node->op()) != + ops_format_agnostic.end(); + if (IsPortDimsFour(*input_node, port) && + ((IsNodeAfterNCHWToNHWC(*input_node) && is_agnostic) || + IsTransposeNCHWToNHWC(input_node->name()))) { continue; } return false; diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index 5bc9e4b..25b1cdc 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -1127,6 +1127,37 @@ class LayoutOptimizerTest(test.TestCase): self._assert_vec_nchw_to_nhwc('ShapeN-0-0', nodes) self.assertAllEqual(output_val_ref, output_val) + def testShapeNFollowedByNotConvertibleNodeReshape(self): + if test.is_gpu_available(cuda_only=True): + x = array_ops.placeholder(dtype='float32') + conv = _two_layer_model(x) + conv_reshape = array_ops.reshape(conv, [1, 1, 1, -1]) + shapen = array_ops.shape_n([conv, conv_reshape]) + shape = array_ops.identity(shapen[1]) + ones = array_ops.ones(shape) + output = math_ops.add_n([conv_reshape, ones]) + + x_val = [1.7] * 784 + with session.Session() as sess: + output_val_ref = sess.run(output, feed_dict={x: x_val}) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run( + output, run_metadata=metadata, feed_dict={x: x_val}) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + expected_num_transposes = 2 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) + self.assertAllEqual(output_val_ref, output_val) + def testLoop(self): if test.is_gpu_available(cuda_only=True): output = _loop() -- 2.7.4