From e8882f768127b71e03efbf193a9c3152ab84802a Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 4 Apr 2018 14:45:05 -0700
Subject: [PATCH] GraphOptimizerStagePipeline to pass through multiple
 optimizer stages, skipping stages that return error.

PiperOrigin-RevId: 191650182
---
 .../grappler/optimizers/arithmetic_optimizer.cc    | 63 ++++++++--------------
 .../grappler/optimizers/graph_optimizer_stage.h    | 61 +++++++++++++++++++++
 2 files changed, 83 insertions(+), 41 deletions(-)

diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 6e27259..919f23f 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -1667,34 +1667,24 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
                                  &frame_map_);
   const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);
 
-  std::vector<std::unique_ptr<ArithmeticOptimizerStage>> stages;
-
-  if (options_.combine_add_to_addn) {
-    stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
-        new AddOpsRewriteStage(ctx, ctx_ext)));
-  }
-  if (options_.hoist_common_factor_out_of_aggregation) {
-    stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
-        new HoistCommonFactorOutOfAggregation(ctx, ctx_ext)));
-  }
-  if (options_.remove_identity_transpose) {
-    stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
-        new RemoveIdentityTranspose(ctx, ctx_ext)));
-  }
-  if (options_.remove_redundant_bitcast) {
-    stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
-        new RemoveRedundantBitcastStage(ctx, ctx_ext)));
-  }
-  if (options_.remove_redundant_cast) {
-    stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
-        new RemoveRedundantCastStage(ctx, ctx_ext)));
-  }
-  if (options_.remove_negation) {
-    stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
-        new RemoveNegationStage(ctx, ctx_ext)));
-  }
-
-  VLOG(1) << "Simplify arithmetic ops using " << stages.size()
+  // Stop pipeline after first stage returning non-empty simplified tensor name.
+  const auto stop = [](const string& result) { return !result.empty(); };
+  GraphOptimizerStagePipeline<string> pipeline(stop);
+
+  if (options_.combine_add_to_addn)
+    pipeline.AddStage<AddOpsRewriteStage>(ctx, ctx_ext);
+  if (options_.hoist_common_factor_out_of_aggregation)
+    pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext);
+  if (options_.remove_identity_transpose)
+    pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext);
+  if (options_.remove_redundant_bitcast)
+    pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
+  if (options_.remove_redundant_cast)
+    pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext);
+  if (options_.remove_negation)
+    pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
+
+  VLOG(1) << "Simplify arithmetic ops using " << pipeline.NumStages()
           << " arithmetic optimization stages";
 
   while (!nodes_to_simplify.Empty()) {
@@ -1707,22 +1697,13 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
     }
 
     // if it was not simplified try to run it through all configured stages
-    if (simplified_tensor.empty()) {
-      for (auto& stage : stages) {
-        if (stage->IsSupported(node)) {
-          TF_RETURN_IF_ERROR(stage->TrySimplify(node, &simplified_tensor));
-          if (!simplified_tensor.empty()) {
-            break;
-          }
-        }
+    if (!stop(simplified_tensor)) {
+      bool optimized = pipeline.PassThroughAllStages(node, &simplified_tensor);
+      if (!optimized) {
+        continue;
       }
     }
 
-    // if it's still empty go to the next Node
-    if (simplified_tensor.empty()) {
-      continue;
-    }
-
     // re-wire consumers of an old node to the new one
     if (NodeName(simplified_tensor) != node->name()) {
       // Always consider simplified_tensor for further optimizations.
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
index be95c00..8d3e965 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
@@ -117,6 +117,9 @@ class GraphOptimizerStage {
       : optimizer_name_(optimizer_name), stage_name_(stage_name), ctx_(ctx) {}
   virtual ~GraphOptimizerStage() = default;
 
+  const string& stage_name() const { return stage_name_; }
+  const string& optimizer_name() const { return optimizer_name_; }
+
   // Check if we should try to simplify node. Returning true doesn't
   // guarantee that node will be simplified.
   //
@@ -179,6 +182,67 @@ class GraphOptimizerStage {
   const GraphOptimizerContext ctx_;
 };
 
+template <typename Result>
+class GraphOptimizerStagePipeline {
+ public:
+  // Break predicate specifies if a pipeline should stop early, and not pass
+  // a node to the next registered optimizer stage, typically that should be
+  // the case when a stage successfully optimized a node, and it wants to
+  // yield control to the optimizer.
+  explicit GraphOptimizerStagePipeline(
+      const std::function<bool(const Result&)> break_predicate)
+      : break_predicate_(break_predicate) {}
+
+  // Add a stage to the pipeline. It should be called with the arguments for
+  // the stage constructor:
+  //
+  //   pipeline.AddStage<FooStage>(constructor_arg1, constructor_arg2);
+  //
+  // Returns a reference to the added stage.
+  template <typename T, typename... Args>
+  T& AddStage(Args&&... args) {
+    auto* stage = new T(std::forward<Args>(args)...);
+    stages_.push_back(std::unique_ptr<T>(stage));
+    return *stage;
+  }
+
+  // Pass a node through all registered optimizer stages, until break predicate
+  // is true.
+  //
+  // Return true, if pipeline exited after a break predicate was evaluated as
+  // 'true', which typically means that a node was optimized by one of the
+  // registered stages.
+  //
+  // Return false, if node was not optimized by any of registered stages.
+  bool PassThroughAllStages(NodeDef* node, Result* result) {
+    for (auto& stage : stages_) {
+      if (stage->IsSupported(node)) {
+        const Status stage_status = stage->TrySimplify(node, result);
+        // Each stage must be "error safe" (just like exception safe). In
+        // case of any error it must leave optimized graph unmodified. Do not
+        // evaluate the break predicate after a failed stage: '*result' is
+        // not guaranteed to hold a meaningful value.
+        if (!stage_status.ok()) {
+          LOG(WARNING) << "Failed to run optimizer " << stage->optimizer_name()
+                       << ", stage " << stage->stage_name()
+                       << ". Error: " << stage_status.error_message();
+        } else if (break_predicate_(*result)) {
+          return true;
+        }
+      }
+    }
+    return false;
+  }
+
+  std::size_t NumStages() { return stages_.size(); }
+
+ private:
+  std::vector<std::unique_ptr<GraphOptimizerStage<Result>>> stages_;
+  std::function<bool(const Result&)> break_predicate_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(GraphOptimizerStagePipeline);
+};
+
 }  // end namespace grappler
 }  // end namespace tensorflow
 
-- 
2.7.4