From 4f6074494d4bf77daac5749224017615bfca239f Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 31 May 2018 15:17:52 -0700
Subject: [PATCH] Move reorder-cast-and-transpose optimization to optimization
 stage.

PiperOrigin-RevId: 198788352
---
 .../grappler/optimizers/arithmetic_optimizer.cc    | 154 +++++++++++++--------
 .../grappler/optimizers/arithmetic_optimizer.h     |   1 +
 .../optimizers/arithmetic_optimizer_test.cc        |  55 +++++---
 3 files changed, 133 insertions(+), 77 deletions(-)

diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 0edea16..ca3f84a 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -194,8 +194,6 @@ void SetSourceDataType(DataType dtype, NodeDef* node) {
   SetDataTypeToAttr(dtype, SourceDataTypeAttrName(*node), node);
 }
 
-bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); }
-
 NodeDef* GetTailOfValuePreservingChain(
     const NodeDef& node, const NodeMap& node_map,
     const std::unordered_set<string>& nodes_to_preserve) {
@@ -1866,6 +1864,100 @@ class RemoveRedundantReshape : public ArithmeticOptimizerStage {
   }
 };
 
+// Reorder Cast and Transpose if beneficial.
+//
+// A common pattern after the layout optimizer is casting a uint8 NHWC
+// image to float before transposing it to NCHW. It is beneficial to reorder
+// the cast and the transpose so that the transpose processes a smaller
+// amount of data. This optimization converts
+//   Transpose(Cast(image, dst_type), perm)
+// to
+//   Cast(Transpose(image, perm), dst_type)
+// when sizeof(image.type) < sizeof(dst_type).
+//
+// TODO(jingyue): This optimization can be generalized to a cast followed by
+// a chain of ops that merely reorder elements (e.g. Reshape and
+// DepthToSpace).
+class ReorderCastAndTranspose : public ArithmeticOptimizerStage {
+ public:
+  explicit ReorderCastAndTranspose(const GraphOptimizerContext& ctx,
+                                   const ArithmeticOptimizerContext& ctx_ext)
+      : ArithmeticOptimizerStage("ReorderCastAndTranspose", ctx, ctx_ext) {}
+  ~ReorderCastAndTranspose() override = default;
+
+  bool IsSupported(const NodeDef* node) const override {
+    return IsTranspose(*node) && NodeIsOnCpuOrGpu(node);
+  }
+
+  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+    const NodeDef* transpose = node;
+
+    // Verify that the input to the Transpose is a Cast op.
+    NodeDef* cast;
+    TF_RETURN_IF_ERROR(GetInputNode(transpose->input(0), &cast));
+    if (!IsCast(*cast)) return Status::OK();
+
+    // Input to the Cast-Transpose chain.
+    NodeDef* input;
+    TF_RETURN_IF_ERROR(GetInputNode(cast->input(0), &input));
+
+    const DataType src_type = GetSourceDataType(*cast);
+    const DataType dst_type = GetDestinationDataType(*cast);
+
+    const string src_type_name = DataTypeString(src_type);
+    const string dst_type_name = DataTypeString(dst_type);
+
+    // Check that the nodes have not already been optimized.
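+    // Optimized nodes are named with the stage prefix plus a type name
+    // (e.g. ".../ArithmeticOptimizer/ReorderCastAndTranspose_float_Cast", as
+    // checked in the unit test below), so finding either name in the node
+    // map means this pattern was already rewritten on an earlier pass and
+    // must be skipped to keep the rewrite idempotent.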
+    const string optimized_cast_name =
+        OptimizedNodeName(ParseNodeScopeAndName(cast->name()), dst_type_name);
+    const string optimized_transpose_name = OptimizedNodeName(
+        ParseNodeScopeAndName(transpose->name()), src_type_name);
+
+    bool is_already_optimized =
+        ctx().node_map->NodeExists(optimized_transpose_name) ||
+        ctx().node_map->NodeExists(optimized_cast_name);
+
+    if (IsNumberType(src_type) && IsNumberType(dst_type) &&
+        DataTypeSize(src_type) < DataTypeSize(dst_type) &&
+        !is_already_optimized) {
+      NodeDef* new_transpose = AddCopyNode(optimized_transpose_name, transpose);
+      (*new_transpose->mutable_attr())["T"].set_type(src_type);
+      new_transpose->set_input(0, cast->input(0));
+
+      ctx().node_map->AddOutput(input->name(), new_transpose->name());
+      ctx().node_map->AddOutput(NodeName(new_transpose->input(1)),
+                                new_transpose->name());
+
+      NodeDef* new_cast = AddCopyNode(optimized_cast_name, cast);
+      new_cast->set_input(0, new_transpose->name());
+      ctx().node_map->AddOutput(new_transpose->name(), new_cast->name());
+
+      AddToOptimizationQueue(new_transpose);
+      ForwardControlDependencies(new_transpose, {cast, node});
+
+      *simplified_node_name = new_cast->name();
+    }
+
+    return Status::OK();
+  }
+
+ private:
+  // This optimization can be dangerous on devices other than CPU and
+  // GPU. The transpose might not be implemented for image.type, or
+  // might be slower with image.type than with dst_type.
+  bool NodeIsOnCpuOrGpu(const NodeDef* node) const {
+    using str_util::StrContains;
+
+    string task;
+    string device;
+
+    return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
+           (StrContains(device, DEVICE_CPU) || StrContains(device, DEVICE_GPU));
+  }
+
+  bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); }
+};
+
 }  // namespace
 
 class UniqueNodes {
@@ -2118,62 +2210,6 @@ void ArithmeticOptimizer::ForwardControlDependencies(
 // ArithmeticOptimizerStage
 string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
     const NodeDef* node, SetVector<NodeDef*>* nodes_to_simplify) {
-
-  if (node->op() == "Transpose") {
-    // Reorder Cast and Transpose if beneficial.
-    //
-    // A common pattern after the layout optimizer is casting an uint8 NHWC
-    // image to float before transposing it to NCHW. It is beneficial to reorder
-    // the cast and the transpose to make the transpose process smaller amount
-    // of data. This optimization converts
-    //   Transpose(Cast(image, dst_type), perm)
-    // to
-    //   Cast(Transpose(image, perm), dst_type)
-    // when sizeof(image.type) < sizeof(dst_type).
-    //
-    // TODO(jingyue): This optimization can be generalized to a cast followed by
-    // a chain of ops that merely reorder elements (e.g. Reshape and
-    // DepthToSpace).
-    const NodeDef* transpose = node;
-    string dontcare;
-    string device;
-    // This optimization can be dangerous on devices other than CPU and GPU. The
-    // transpose might not be implemented for image.type, or might be slower
-    // with image.type than with dst_type.
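+  // NOTE: The inline Cast/Transpose reordering that previously lived here
+  // has moved to the ReorderCastAndTranspose stage defined above.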
-    if (DeviceNameUtils::SplitDeviceName(transpose->device(), &dontcare,
-                                         &device) &&
-        (str_util::StrContains(device, DEVICE_CPU) ||
-         str_util::StrContains(device, DEVICE_GPU))) {
-      const NodeDef* cast = node_map_->GetNode(transpose->input(0));
-      if (cast->op() == "Cast") {
-        const NodeDef* input = node_map_->GetNode(cast->input(0));
-        const DataType src_type = GetSourceDataType(*cast);
-        const DataType dst_type = GetDestinationDataType(*cast);
-        if (IsNumberType(src_type) && IsNumberType(dst_type) &&
-            DataTypeSize(src_type) < DataTypeSize(dst_type) &&
-            !OptimizedNodeExists(*cast, DataTypeString(dst_type)) &&
-            !OptimizedNodeExists(*transpose, DataTypeString(src_type))) {
-          NodeDef* new_transpose = AddNode(*transpose, DataTypeString(src_type),
-                                           /*copy_node=*/true);
-          (*new_transpose->mutable_attr())["T"].set_type(src_type);
-          new_transpose->set_input(0, cast->input(0));
-          node_map_->AddOutput(input->name(), new_transpose->name());
-          node_map_->AddOutput(NodeName(new_transpose->input(1)),
-                               new_transpose->name());
-
-          NodeDef* new_cast =
-              AddNode(*cast, DataTypeString(dst_type), /*copy_node=*/true);
-          new_cast->set_input(0, new_transpose->name());
-          node_map_->AddOutput(new_transpose->name(), new_cast->name());
-
-          nodes_to_simplify->PushBack(new_transpose);
-          ForwardControlDependencies(new_transpose, {cast, node});
-          return new_cast->name();
-        }
-      }
-    }
-  }
-
   // Fold a multiply of a scalar into the following convolution. This folding
   // can jump across nodes that merely reorders data (such as reshape and
   // transpose). For example, we can optimize
@@ -2462,6 +2498,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
     pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
   if (options_.remove_logical_not)
     pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext);
+  if (options_.reorder_cast_and_transpose)
+    pipeline.AddStage<ReorderCastAndTranspose>(ctx, ctx_ext);
   if (options_.hoist_cwise_unary_chains)
     pipeline.AddStage<HoistCWiseUnaryChainsStage>(ctx, ctx_ext);
   if (options_.convert_sqrt_div_to_rsqrt_mul)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 9f8ec85..0fce23a 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -72,6 +72,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
     bool remove_redundant_bitcast = true;
     bool remove_redundant_cast = true;
     bool remove_redundant_reshape = true;
+    bool reorder_cast_and_transpose = true;
 
     // Choose which arithmetic optimizer stages will be enabled for a given
     // optimization level by default.
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 43355ef..02f76df 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -97,12 +97,22 @@ class ArithmeticOptimizerTest : public GrapplerTest {
   }
 
   // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent.
+  // Optionally run a constant folding pass before pruning.
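+  // (Some rewrites, such as OptimizeCastMulTransposeConv below, leave behind
+  // constant expressions, e.g. the scaled convolution weights, that only
+  // collapse into a single node once they are folded and the dead nodes are
+  // pruned.)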
   void OptimizeTwiceAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item,
-                             GraphDef* output) {
+                             GraphDef* output, bool const_folding = false) {
     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
+
     item->graph.Swap(output);
     output->Clear();
     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
+
+    if (const_folding) {
+      item->graph.Swap(output);
+      output->Clear();
+      TF_EXPECT_OK(ConstantFolding(/*cpu_device=*/nullptr)
+                       .Optimize(nullptr, *item, output));
+    }
+
     item->graph.Swap(output);
     output->Clear();
     TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
@@ -127,6 +137,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
     options.remove_redundant_reshape = false;
     options.remove_negation = false;
     options.remove_logical_not = false;
+    options.reorder_cast_and_transpose = false;
     optimizer->options_ = options;
   }
@@ -179,6 +190,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
     optimizer->options_.remove_negation = true;
   }
 
+  void EnableOnlyReorderCastAndTranspose(ArithmeticOptimizer* optimizer) {
+    DisableAllStages(optimizer);
+    optimizer->options_.reorder_cast_and_transpose = true;
+  }
+
   void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) {
     DisableAllStages(optimizer);
     optimizer->options_.hoist_cwise_unary_chains = true;
@@ -1540,6 +1556,7 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
   //   =>
   //   Conv2D(Cast(Transpose(I)), W*S)
   tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/gpu:0");
+
   Output inputs =
       ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
   Output cast = ops::Cast(s, inputs, DT_FLOAT);
@@ -1557,28 +1574,28 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
 
   GraphDef output;
-  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
-
-  // Run the optimizer twice to make sure the rewrite is idempotent.
-  item.graph.Swap(&output);
-  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
+  ArithmeticOptimizer optimizer;
+  OptimizeTwiceAndPrune(&optimizer, &item, &output, /*const_folding=*/true);
 
-  item.graph.Swap(&output);
-  TF_EXPECT_OK(
-      ConstantFolding(/*cpu_device=*/nullptr).Optimize(nullptr, item, &output));
+  NodeMap node_map(&output);
 
-  item.graph.Swap(&output);
-  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+  // Expected names for the optimized nodes.
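+  // These mirror the stage's OptimizedNodeName() scheme: the
+  // "ArithmeticOptimizer/ReorderCastAndTranspose_" prefix, a data type
+  // (uint8 for the new Transpose, float for the new Cast), and the name of
+  // the original node.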
+ const string p = "ArithmeticOptimizer/ReorderCastAndTranspose_"; + const string optimized_cast_name = strings::StrCat(p, "float_Cast"); + const string optimized_transpose_name = strings::StrCat(p, "uint8_Transpose"); - NodeMap node_map(&output); - const NodeDef* inputs_node = CHECK_NOTNULL(node_map.GetNode("Placeholder")); - const NodeDef* transpose_node = - CHECK_NOTNULL(node_map.GetNode(OptimizedName("Transpose_uint8"))); - const NodeDef* cast_node = - CHECK_NOTNULL(node_map.GetNode(OptimizedName("Cast_float"))); + const NodeDef* inputs_node = node_map.GetNode("Placeholder"); + const NodeDef* transpose_node = node_map.GetNode(optimized_transpose_name); + const NodeDef* cast_node = node_map.GetNode(optimized_cast_name); const NodeDef* weights_node = - CHECK_NOTNULL(node_map.GetNode(OptimizedName("weights_scaled_Conv2D"))); - const NodeDef* conv_node = CHECK_NOTNULL(node_map.GetNode("Conv2D")); + node_map.GetNode(OptimizedName("weights_scaled_Conv2D")); + const NodeDef* conv_node = node_map.GetNode("Conv2D"); + + ASSERT_TRUE(inputs_node != nullptr); + ASSERT_TRUE(transpose_node != nullptr); + ASSERT_TRUE(cast_node != nullptr); + ASSERT_TRUE(weights_node != nullptr); + ASSERT_TRUE(conv_node != nullptr); EXPECT_EQ(output.node_size(), 7); EXPECT_EQ(transpose_node->input(0), inputs_node->name()); -- 2.7.4