&frame_map_);
const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);
- std::vector<std::unique_ptr<ArithmeticOptimizerStage>> stages;
-
- if (options_.combine_add_to_addn) {
- stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
- new AddOpsRewriteStage(ctx, ctx_ext)));
- }
- if (options_.hoist_common_factor_out_of_aggregation) {
- stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
- new HoistCommonFactorOutOfAggregation(ctx, ctx_ext)));
- }
- if (options_.remove_identity_transpose) {
- stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
- new RemoveIdentityTranspose(ctx, ctx_ext)));
- }
- if (options_.remove_redundant_bitcast) {
- stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
- new RemoveRedundantBitcastStage(ctx, ctx_ext)));
- }
- if (options_.remove_redundant_cast) {
- stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
- new RemoveRedundantCastStage(ctx, ctx_ext)));
- }
- if (options_.remove_negation) {
- stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
- new RemoveNegationStage(ctx, ctx_ext)));
- }
-
- VLOG(1) << "Simplify arithmetic ops using " << stages.size()
+ // Stop pipeline after first stage returning non-empty simplified tensor name.
+ const auto stop = [](const string& result) { return !result.empty(); };
+ GraphOptimizerStagePipeline<string> pipeline(stop);
+
+ if (options_.combine_add_to_addn)
+ pipeline.AddStage<AddOpsRewriteStage>(ctx, ctx_ext);
+ if (options_.hoist_common_factor_out_of_aggregation)
+ pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext);
+ if (options_.remove_identity_transpose)
+ pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext);
+ if (options_.remove_redundant_bitcast)
+ pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
+ if (options_.remove_redundant_cast)
+ pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext);
+ if (options_.remove_negation)
+ pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
+
+ VLOG(1) << "Simplify arithmetic ops using " << pipeline.NumStages()
<< " arithmetic optimization stages";
while (!nodes_to_simplify.Empty()) {
}
// if it was not simplified try to run it through all configured stages
- if (simplified_tensor.empty()) {
- for (auto& stage : stages) {
- if (stage->IsSupported(node)) {
- TF_RETURN_IF_ERROR(stage->TrySimplify(node, &simplified_tensor));
- if (!simplified_tensor.empty()) {
- break;
- }
- }
+ if (!stop(simplified_tensor)) {
+ bool optimized = pipeline.PassThroughAllStages(node, &simplified_tensor);
+ if (!optimized) {
+ continue;
}
}
- // if it's still empty go to the next Node
- if (simplified_tensor.empty()) {
- continue;
- }
-
// re-wire consumers of an old node to the new one
if (NodeName(simplified_tensor) != node->name()) {
// Always consider simplified_tensor for further optimizations.
: optimizer_name_(optimizer_name), stage_name_(stage_name), ctx_(ctx) {}
virtual ~GraphOptimizerStage() = default;
+ const string& stage_name() const { return stage_name_; }
+ const string& optimizer_name() const { return optimizer_name_; }
+
// Check if we should try to simplify node. Returning true doesn't
// guarantee that node will be simplified.
//
const GraphOptimizerContext ctx_;
};
+template <typename Result>
+class GraphOptimizerStagePipeline {
+ public:
+ // Break predicate specifies if a pipeline should stop early, and not pass
+ // a node to the next registered optimizer stage, typically that should be the
+ // case when a stage successfully optimized a node, and it wants to yield
+ // control to the optimizer.
+ explicit GraphOptimizerStagePipeline(
+ const std::function<bool(const Result&)> break_predicate)
+ : break_predicate_(break_predicate) {}
+
+ // Add a stage to the pipeline. It should be called with the arguments for the
+ // stage constructor:
+ //
+ // pipeline.AddStage<FooStage>(constructor_arg1, constructor_arg2);
+ //
+ // Returns a reference to the added stage.
+ template <typename T, typename... Args>
+ T& AddStage(Args&&... args) {
+ auto stage = new T(std::forward<Args>(args)...);
+ stages_.push_back(std::unique_ptr<T>(stage));
+ return *stage;
+ }
+
+ // Pass a node through all registered optimizer stages, until break predicate
+ // is true.
+ //
+ // Return true, if pipeline exited after a break predicate was evaluated as
+ // 'true', which typically means that a node was optimized by one of the
+ // registered stages.
+ //
+ // Return false, if node was not optimized by any of registered stages.
+ bool PassThroughAllStages(NodeDef* node, Result* result) {
+ for (auto& stage : stages_) {
+ if (stage->IsSupported(node)) {
+ const Status stage_status = stage->TrySimplify(node, result);
+ // Each stage must be "error safe" (just like exception safe). In
+ // case of any error it must leave optimized graph unmodified.
+ if (!stage_status.ok()) {
+ LOG(WARNING) << "Failed to run optimizer " << stage->optimizer_name()
+ << ", stage " << stage->stage_name()
+ << ". Error: " << stage_status.error_message();
+ }
+ if (break_predicate_(*result)) return true;
+ }
+ }
+ return false;
+ }
+
+ std::size_t NumStages() { return stages_.size(); }
+
+ private:
+ std::vector<std::unique_ptr<GraphOptimizerStage<Result>>> stages_;
+ std::function<bool(const Result&)> break_predicate_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GraphOptimizerStagePipeline);
+};
+
} // end namespace grappler
} // end namespace tensorflow