GraphOptimizerStagePipeline to pass through multiple optimizer stages, skipping stages that return error.
author	A. Unique TensorFlower <gardener@tensorflow.org>
	Wed, 4 Apr 2018 21:45:05 +0000 (14:45 -0700)
committer	TensorFlower Gardener <gardener@tensorflow.org>
	Wed, 4 Apr 2018 21:47:53 +0000 (14:47 -0700)

PiperOrigin-RevId: 191650182

tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
tensorflow/core/grappler/optimizers/graph_optimizer_stage.h

index 6e27259..919f23f 100644 (file)
@@ -1667,34 +1667,24 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
                                   &frame_map_);
   const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);
 
-  std::vector<std::unique_ptr<ArithmeticOptimizerStage>> stages;
-
-  if (options_.combine_add_to_addn) {
-    stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
-        new AddOpsRewriteStage(ctx, ctx_ext)));
-  }
-  if (options_.hoist_common_factor_out_of_aggregation) {
-    stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
-        new HoistCommonFactorOutOfAggregation(ctx, ctx_ext)));
-  }
-  if (options_.remove_identity_transpose) {
-    stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
-        new RemoveIdentityTranspose(ctx, ctx_ext)));
-  }
-  if (options_.remove_redundant_bitcast) {
-    stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
-        new RemoveRedundantBitcastStage(ctx, ctx_ext)));
-  }
-  if (options_.remove_redundant_cast) {
-    stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
-        new RemoveRedundantCastStage(ctx, ctx_ext)));
-  }
-  if (options_.remove_negation) {
-    stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
-        new RemoveNegationStage(ctx, ctx_ext)));
-  }
-
-  VLOG(1) << "Simplify arithmetic ops using " << stages.size()
+  // Stop pipeline after first stage returning non-empty simplified tensor name.
+  const auto stop = [](const string& result) { return !result.empty(); };
+  GraphOptimizerStagePipeline<string> pipeline(stop);
+
+  if (options_.combine_add_to_addn)
+    pipeline.AddStage<AddOpsRewriteStage>(ctx, ctx_ext);
+  if (options_.hoist_common_factor_out_of_aggregation)
+    pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext);
+  if (options_.remove_identity_transpose)
+    pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext);
+  if (options_.remove_redundant_bitcast)
+    pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
+  if (options_.remove_redundant_cast)
+    pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext);
+  if (options_.remove_negation)
+    pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
+
+  VLOG(1) << "Simplify arithmetic ops using " << pipeline.NumStages()
           << " arithmetic optimization stages";
 
   while (!nodes_to_simplify.Empty()) {
@@ -1707,22 +1697,13 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
     }
 
     // if it was not simplified try to run it through all configured stages
-    if (simplified_tensor.empty()) {
-      for (auto& stage : stages) {
-        if (stage->IsSupported(node)) {
-          TF_RETURN_IF_ERROR(stage->TrySimplify(node, &simplified_tensor));
-          if (!simplified_tensor.empty()) {
-            break;
-          }
-        }
+    if (!stop(simplified_tensor)) {
+      bool optimized = pipeline.PassThroughAllStages(node, &simplified_tensor);
+      if (!optimized) {
+        continue;
       }
     }
 
-    // if it's still empty go to the next Node
-    if (simplified_tensor.empty()) {
-      continue;
-    }
-
     // re-wire consumers of an old node to the new one
     if (NodeName(simplified_tensor) != node->name()) {
       // Always consider simplified_tensor for further optimizations.
index be95c00..8d3e965 100644 (file)
@@ -117,6 +117,9 @@ class GraphOptimizerStage {
       : optimizer_name_(optimizer_name), stage_name_(stage_name), ctx_(ctx) {}
   virtual ~GraphOptimizerStage() = default;
 
+  const string& stage_name() const { return stage_name_; }
+  const string& optimizer_name() const { return optimizer_name_; }
+
   // Check if we should try to simplify node. Returning true doesn't
   // guarantee that node will be simplified.
   //
@@ -179,6 +182,64 @@ class GraphOptimizerStage {
   const GraphOptimizerContext ctx_;
 };
 
+template <typename Result>
+class GraphOptimizerStagePipeline {
+ public:
+  // Break predicate specifies if a pipeline should stop early, and not pass
+  // a node to the next registered optimizer stage, typically that should be the
+  // case when a stage successfully optimized a node, and it wants to yield
+  // control to the optimizer.
+  explicit GraphOptimizerStagePipeline(
+      const std::function<bool(const Result&)> break_predicate)
+      : break_predicate_(break_predicate) {}
+
+  // Add a stage to the pipeline. It should be called with the arguments for the
+  // stage constructor:
+  //
+  //   pipeline.AddStage<FooStage>(constructor_arg1, constructor_arg2);
+  //
+  // Returns a reference to the added stage.
+  template <typename T, typename... Args>
+  T& AddStage(Args&&... args) {
+    auto stage = new T(std::forward<Args>(args)...);
+    stages_.push_back(std::unique_ptr<T>(stage));
+    return *stage;
+  }
+
+  // Pass a node through all registered optimizer stages, until break predicate
+  // is true.
+  //
+  // Return true, if pipeline exited after a break predicate was evaluated as
+  // 'true', which typically means that a node was optimized by one of the
+  // registered stages.
+  //
+  // Return false, if node was not optimized by any of registered stages.
+  bool PassThroughAllStages(NodeDef* node, Result* result) {
+    for (auto& stage : stages_) {
+      if (stage->IsSupported(node)) {
+        const Status stage_status = stage->TrySimplify(node, result);
+        // Each stage must be "error safe" (just like exception safe). In
+        // case of any error it must leave optimized graph unmodified.
+        if (!stage_status.ok()) {
+          LOG(WARNING) << "Failed to run optimizer " << stage->optimizer_name()
+                       << ", stage " << stage->stage_name()
+                       << ". Error: " << stage_status.error_message();
+        }
+        if (break_predicate_(*result)) return true;
+      }
+    }
+    return false;
+  }
+
+  std::size_t NumStages() { return stages_.size(); }
+
+ private:
+  std::vector<std::unique_ptr<GraphOptimizerStage<Result>>> stages_;
+  std::function<bool(const Result&)> break_predicate_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(GraphOptimizerStagePipeline);
+};
+
 }  // end namespace grappler
 }  // end namespace tensorflow