struct ArithmeticOptimizerContext {
ArithmeticOptimizerContext(
const std::unordered_set<string>* nodes_to_preserve,
- GraphDef* optimized_graph, NodeMap* node_map,
+ GraphDef* optimized_graph, NodeMap* node_map, FrameMap* frame_map,
SetVector<NodeDef*>* nodes_to_simplify)
: nodes_to_preserve(nodes_to_preserve),
optimized_graph(optimized_graph),
node_map(node_map),
+ frame_map(frame_map),
nodes_to_simplify(nodes_to_simplify) {}
const std::unordered_set<string>* nodes_to_preserve;
GraphDef* optimized_graph;
NodeMap* node_map;
+ FrameMap* frame_map;
SetVector<NodeDef*>* nodes_to_simplify;
};
// Base class for single arithmetic optimization: e.g. Bitcast optimization,
// AddOps optimization, etc...
+// TODO(ezhulenev): extract this class to be reused by other multi-stage
+// graph optimizers (const_folding, dependency_optimizer, etc...)
class ArithmeticOptimizerStage {
public:
- explicit ArithmeticOptimizerStage(ArithmeticOptimizerContext ctx)
- : ctx_(ctx) {}
+ explicit ArithmeticOptimizerStage(const string& name,
+ const ArithmeticOptimizerContext& ctx)
+ : name_(name), ctx_(ctx) {}
virtual ~ArithmeticOptimizerStage() = default;
// Check if we should try to simplify node. Returning true doesn't
string* simplified_node_name) = 0;
protected:
+ struct ScopedNodeName {
+ string scope;
+ string name;
+ };
+
+ // Splits a full node name into its scope ("a/b/c") and the unscoped node
+ // name ("Add_1"). Scope is empty if the name has no '/' separator.
+ const ScopedNodeName ParseScopedNodeName(const string& name) const {
+ // Use the char overload: no temporary string, same behavior.
+ const auto pos = name.find_last_of('/');
+ if (pos == string::npos) {
+ return {"", name};
+ } else {
+ return {name.substr(0, pos), name.substr(pos + 1)};
+ }
+ }
+
+ // Prefix optimized node name with the stage name and a rewrite_rule name.
+ const string OptimizedNodeName(const string& rewrite_rule,
+ const ScopedNodeName& scoped_node_name) const {
+ return MakeOptimizedNodeName(strings::StrCat(name_, "_", rewrite_rule),
+ scoped_node_name);
+ }
+
+ // Prefix optimized node name with the stage name and a rewrite_rule name;
+ // also appends the unscoped names of `node_names` for uniqueness when the
+ // rewrite consumes multiple nodes.
+ const string OptimizedNodeName(const string& rewrite_rule,
+ const ScopedNodeName& scoped_node_name,
+ const std::vector<string>& node_names) const {
+ return MakeOptimizedNodeName(strings::StrCat(name_, "_", rewrite_rule),
+ scoped_node_name, node_names);
+ }
+
+ // Prefix optimized node name with the stage name only.
+ const string OptimizedNodeName(const ScopedNodeName& scoped_node_name) const {
+ return MakeOptimizedNodeName(name_, scoped_node_name);
+ }
+
+ // Prefix optimized node name with the stage name only; also appends the
+ // unscoped names of `node_names` (see MakeOptimizedNodeName).
+ const string OptimizedNodeName(const ScopedNodeName& scoped_node_name,
+ const std::vector<string>& node_names) const {
+ return MakeOptimizedNodeName(name_, scoped_node_name, node_names);
+ }
+
// Simplification graph rewrite can create additional nodes that are inputs
// to final simplified node, they can be also added to the arithmetic
// optimizer queue for further optimization.
}
}
- ArithmeticOptimizerContext ctx_;
+ // Adds a copy of `node_to_copy` to the optimized graph under the given
+ // name and registers it in the node map. CHECK-fails if a node with this
+ // name already exists in the graph.
+ NodeDef* AddCopyNode(const string& name, const NodeDef* node_to_copy) {
+ CHECK(node_to_copy != nullptr);
+ CHECK(!ctx_.node_map->NodeExists(name))
+ << "Node " << name << " already exists in a graph";
+ NodeDef* new_node = ctx_.optimized_graph->add_node();
+ *new_node = *node_to_copy;
+ new_node->set_name(name);
+ ctx_.node_map->AddNode(name, new_node);
+ return new_node;
+ }
+
+ // Adds a new node with the given name (and no other fields set) to the
+ // optimized graph and registers it in the node map. CHECK-fails if a node
+ // with this name already exists in the graph.
+ NodeDef* AddEmptyNode(const string& name) {
+ CHECK(!ctx_.node_map->NodeExists(name))
+ << "Node " << name << " already exists in a graph";
+ NodeDef* new_node = ctx_.optimized_graph->add_node();
+ new_node->set_name(name);
+ ctx_.node_map->AddNode(name, new_node);
+ return new_node;
+ }
+
+ // TODO(ezhulenev): remove this method from ArithmeticOptimizer when all
+ // optimizations will be migrated to stages
+ //
+ // If `old_node` belongs to a frame, registers every node of `new_nodes` in
+ // the same frame, and (when both are non-empty) adds a control dependency
+ // from `source_for_ctrl_dep` to each node in `sinks_for_control_dep`.
+ // Does nothing for nodes outside of any frame.
+ void AddFrameControlDeps(const NodeDef* old_node,
+ const std::vector<NodeDef*>& new_nodes,
+ const string& source_for_ctrl_dep,
+ const std::vector<NodeDef*>& sinks_for_control_dep) {
+ const auto frame_it = ctx_.frame_map->find(old_node);
+ if (frame_it != ctx_.frame_map->end()) {
+ for (auto node : new_nodes) {
+ ctx_.frame_map->emplace(node, frame_it->second);
+ }
+ if (!source_for_ctrl_dep.empty() && !sinks_for_control_dep.empty()) {
+ const string ctrl_dep = ConstantFolding::AddControlDependency(
+ source_for_ctrl_dep, ctx_.optimized_graph, ctx_.node_map);
+ for (auto node : sinks_for_control_dep) {
+ MaybeAddControlInput(ctrl_dep, node, ctx_.optimized_graph,
+ ctx_.node_map);
+ }
+ }
+ }
+ }
+
+ const string name_;
+ const ArithmeticOptimizerContext ctx_;
+
+ private:
+ // Get a name for a new node obtained by optimizing a single node of the
+ // original graph. The optimized node is placed under the original node scope.
+ //
+ // Node name uniqueness is guaranteed by the unique name of the original
+ // node within the same scope.
+ //
+ // Example: MakeOptimizedNodeName("AwesomeRewrite", "a/b/c/Add_1")
+ // Optimized name: "a/b/c/ArithmeticOptimizer/AwesomeRewrite_Add_1"
+ const string MakeOptimizedNodeName(
+ const string& prefix, const ScopedNodeName& scoped_node_name) const {
+ string node_name;
+ strings::StrAppend(&node_name, scoped_node_name.scope);
+ if (!node_name.empty()) strings::StrAppend(&node_name, "/");
+ strings::StrAppend(&node_name, kArithmeticOptimizer, "/", prefix, "_",
+ scoped_node_name.name);
+ return node_name;
+ }
+
+ // Get a name for a new node obtained by optimizing multiple nodes of the
+ // original graph, starting from "root". The optimized node is placed under
+ // the original scope of a "root" node.
+ //
+ // Node name uniqueness is guaranteed by the unique name of the "root" node
+ // within the same scope.
+ //
+ // Example:
+ // MakeOptimizedNodeName("AwesomeRewrite", "a/b/Add_AB", ["x/y/Add_XY"])
+ // Optimized name:
+ // "a/b/ArithmeticOptimizer/AwesomeRewrite_Add_AB_Add_XY"
+ const string MakeOptimizedNodeName(
+ const string& prefix, const ScopedNodeName& scoped_node_name,
+ const std::vector<string>& node_names) const {
+ string node_name = MakeOptimizedNodeName(prefix, scoped_node_name);
+ for (const string& optimized : node_names) {
+ auto scoped_node = ParseScopedNodeName(optimized);
+ strings::StrAppend(&node_name, "_", scoped_node.name);
+ }
+ return node_name;
+ }
};
// Rewrite a tree of Add/AddN with a single AddN operation, consuming all the
// Add/AddN nodes that were absorbed into the tree during the rewrite.
class AddOpsRewriteStage : public ArithmeticOptimizerStage {
public:
- explicit AddOpsRewriteStage(ArithmeticOptimizerContext ctx)
- : ArithmeticOptimizerStage(ctx), rewritten_nodes_() {}
+ explicit AddOpsRewriteStage(const ArithmeticOptimizerContext& ctx)
+ : ArithmeticOptimizerStage("AddOpsRewrite", ctx), rewritten_nodes_() {}
~AddOpsRewriteStage() override = default;
AddOpsGroup group;
TF_RETURN_IF_ERROR(CreateAddOpsGroup(node, &group));
- if (!group.absorbed_nodes.empty()) {
+ if (!group.absorbed_nodes.empty() && !IsRewritten(group)) {
*simplified_node_name = RewriteAddOpsGroup(group);
}
DrivesControlDependency(*node));
}
+ // Check that the optimized group node name doesn't already exist. It might
+ // happen if the graph is optimized multiple times without pruning between
+ // invocations.
+ bool IsRewritten(const AddOpsGroup& group) const {
+ return ctx_.node_map->NodeExists(AddOpsGroupName(group));
+ }
+
// Create an AddOpsGroup with a root in a given node
Status CreateAddOpsGroup(const NodeDef* root_node, AddOpsGroup* group) {
group->root_node = root_node;
return Status::OK();
}
- const std::pair<string, string> ParseNodeScopeAndName(const string& name) {
- auto pos = name.find_last_of("/");
- if (pos == string::npos) {
- return {"", name};
- } else {
- return {name.substr(0, pos), name.substr(pos + 1)};
- }
- }
-
// New node for AddOpsGroup is added to the same scope as a root_node. All
// absorbed nodes are stripped of their scope, and only names are used in a
// new node name.
//
// Example: AddOpsGroup(root="a/b/c/Add_2", absorbed=["d/Add_1", "e/Add"])
// node_name="a/b/c/ArithmeticOptimizer/AddOpsRewrite_Add_2_Add_1_Add"
- string AddOpsGroupName(const AddOpsGroup& group) {
+ string AddOpsGroupName(const AddOpsGroup& group) const {
CHECK_NOTNULL(group.root_node);
- string node_name;
- auto root_node = ParseNodeScopeAndName(group.root_node->name());
- auto root_scope = root_node.first;
- auto root_name = root_node.second;
- if (!root_scope.empty()) {
- strings::StrAppend(&node_name, root_scope, "/");
- }
+ auto root = ParseScopedNodeName(group.root_node->name());
- strings::StrAppend(&node_name, kArithmeticOptimizer, "/", "AddOpsGroup_",
- root_name);
- for (const NodeDef* absorbed : group.absorbed_nodes) {
- auto absorbed_node = ParseNodeScopeAndName(absorbed->name());
- strings::StrAppend(&node_name, "_", absorbed_node.second);
- }
- return node_name;
+ std::vector<string> absorbed_node_names(group.absorbed_nodes.size());
+ std::transform(group.absorbed_nodes.begin(), group.absorbed_nodes.end(),
+ absorbed_node_names.begin(),
+ [](const NodeDef* node) { return node->name(); });
+
+ return OptimizedNodeName(root, absorbed_node_names);
}
// Create a new node for a AddOpsGroup and return it's name.
// copy attributes from a root node
DataType dtype = group.root_node->attr().at("T").type();
- // add new node
- NodeDef* added_node = ctx_.optimized_graph->add_node();
- added_node->set_name(node_name);
+ // add new AddN node
+ NodeDef* added_node = AddEmptyNode(node_name);
added_node->set_op("AddN");
added_node->set_device(group.root_node->device());
(*added_node->mutable_attr())["T"].set_type(dtype);
(*added_node->mutable_attr())["N"].set_i(group.inputs.size());
- ctx_.node_map->AddNode(node_name, added_node);
- for (string input : group.inputs) {
+ // all inputs of absorbed nodes are added to the new node
+ for (const string& input : group.inputs) {
ctx_.node_map->AddOutput(input, node_name);
- added_node->add_input(std::move(input));
+ added_node->add_input(input);
}
VLOG(1) << "Absorbed " << group.absorbed_nodes.size()
std::unordered_set<string> rewritten_nodes_;
};
+// Use the commutativity and (left- and right-) distributive property of
+// multiplication over addition to hoist common factors out of aggregate nodes
+// where all the inputs are Mul nodes. This pattern occurs frequently in
+// regularization terms for the gradients during training.
+//
+// For example, we can rewrite an expression of the form:
+// AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
+// to the following:
+// Mul(x, AddN(y1, y2, y3, ... yn))
+class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
+ public:
+ explicit HoistCommonFactorOutOfAggregation(
+ const ArithmeticOptimizerContext& ctx)
+ : ArithmeticOptimizerStage("HoistCommonFactor", ctx) {}
+ ~HoistCommonFactorOutOfAggregation() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsAggregate(*node) && NumNonControlInputs(*node) > 1 &&
+ !IsRewritten(node);
+ }
+
+ Status TrySimplify(const NodeDef* node,
+ string* simplified_node_name) override {
+ CHECK(IsSupported(node));
+
+ std::set<string> common_factors;
+ TF_RETURN_IF_ERROR(GetCommonFactors(node, &common_factors));
+
+ if (common_factors.size() == 1) {
+ const string& common_factor = *common_factors.begin();
+
+ // Gather up the non-shared factors
+ bool shapes_match = true;
+ std::vector<string> unique_factors;
+ TF_RETURN_IF_ERROR(GetUniqueFactors(node, common_factor, &shapes_match,
+ &unique_factors));
+
+ if (shapes_match) {
+ NodeDef* input_0;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input_0));
+
+ // Use a copy of the first Mul node for the outer multiplication.
+ NodeDef* new_mul_node = AddCopyNode(OuterMulNodeName(node), input_0);
+ // And a copy of aggregation node as one of the inner operands
+ NodeDef* new_add_node = AddCopyNode(InnerAddNodeName(node), node);
+
+ new_mul_node->set_device(node->device());
+ new_mul_node->set_input(0, common_factor);
+ new_mul_node->set_input(1, new_add_node->name());
+
+ ctx_.node_map->AddOutput(common_factor, new_mul_node->name());
+ ctx_.node_map->AddOutput(new_add_node->name(), new_mul_node->name());
+
+ // Hoist non-shared factors up into the new AddN node.
+ for (int i = 0; i < unique_factors.size(); ++i) {
+ new_add_node->set_input(i, unique_factors[i]);
+ }
+
+ // Add frame dependencies that the original node might have had.
+ AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor,
+ {new_add_node});
+
+ // optimize new inner aggregation node
+ AddToOptimizationQueue(new_add_node);
+ // do not optimize the same node twice
+ rewritten_nodes_.insert(node->name());
+ *simplified_node_name = new_mul_node->name();
+ }
+ }
+ return Status::OK();
+ }
+
+ private:
+ // Get a name for the new outer Mul node
+ string OuterMulNodeName(const NodeDef* node) const {
+ auto scoped_node = ParseScopedNodeName(node->name());
+ return OptimizedNodeName("Mul", scoped_node);
+ }
+
+ // Get a name for the new inner Add node
+ string InnerAddNodeName(const NodeDef* node) const {
+ auto scoped_node = ParseScopedNodeName(node->name());
+ return OptimizedNodeName("Add", scoped_node);
+ }
+
+ // Determine the set of common factors if the input nodes are all Mul nodes.
+ // On return `common_factors` holds the input names shared by every Mul, or
+ // is empty if any non-control input is not a Mul.
+ Status GetCommonFactors(const NodeDef* node,
+ std::set<string>* common_factors) const {
+ CHECK(common_factors->empty());
+
+ for (int i = 0; i < node->input_size(); ++i) {
+ if (i > 0 && common_factors->empty()) break;
+ if (IsControlInput(node->input(i))) break;
+
+ NodeDef* input;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(i), &input));
+
+ if (!IsMul(*input)) {
+ common_factors->clear();
+ break;
+ }
+
+ std::set<string> factors_i{input->input(0), input->input(1)};
+ if (i == 0) {
+ std::swap(*common_factors, factors_i);
+ } else {
+ std::set<string> intersection;
+ std::set_intersection(
+ factors_i.begin(), factors_i.end(), common_factors->begin(),
+ common_factors->end(),
+ std::inserter(intersection, intersection.begin()));
+ std::swap(*common_factors, intersection);
+ }
+ }
+ return Status::OK();
+ }
+
+ // Gather up the non-shared factors (the y's in the example).
+ // Unless the aggregation is Add, we have to make sure that all the y's
+ // have the same shape since the other aggregation ops do not support
+ // broadcasting.
+ Status GetUniqueFactors(const NodeDef* node, const string& common_factor,
+ bool* shapes_match,
+ std::vector<string>* unique_factors) const {
+ *shapes_match = true;
+ unique_factors->reserve(node->input_size());
+
+ // Note: the loop condition must dereference shapes_match. Testing the
+ // pointer itself is always true, so the loop would keep running after a
+ // mismatch and a later matching pair could overwrite the earlier
+ // negative result, allowing an invalid rewrite.
+ for (int i = 0; i < node->input_size() && *shapes_match; ++i) {
+ const string& input = node->input(i);
+ if (IsControlInput(input)) {
+ break;
+ }
+ NodeDef* mul_node;
+ TF_RETURN_IF_ERROR(GetInputNode(input, &mul_node));
+ // The non-shared factor is whichever Mul input is not common_factor.
+ const int unique_factor_index =
+ mul_node->input(0) == common_factor ? 1 : 0;
+ unique_factors->push_back(mul_node->input(unique_factor_index));
+ if (i > 0 && !IsAdd(*node)) {
+ *shapes_match = ShapesEqual(unique_factors->front(),
+ unique_factors->back(), *ctx_.node_map);
+ }
+ }
+ return Status::OK();
+ }
+
+ bool IsRewritten(const NodeDef* node) const {
+ // if graph rewrite happens in multiple passes without graph pruning between
+ // them, it's possible that rewritten node already exists in a graph
+ return rewritten_nodes_.find(node->name()) != rewritten_nodes_.end() ||
+ ctx_.node_map->NodeExists(OuterMulNodeName(node));
+ }
+
+ // keep names of the nodes that were optimized by this stage
+ std::unordered_set<string> rewritten_nodes_;
+};
+
// Removes inverse transpose nodes
class RemoveInverseTranspose : public ArithmeticOptimizerStage {
public:
- explicit RemoveInverseTranspose(ArithmeticOptimizerContext ctx)
- : ArithmeticOptimizerStage(ctx) {}
+ explicit RemoveInverseTranspose(const ArithmeticOptimizerContext& ctx)
+ : ArithmeticOptimizerStage("RemoveInverseTranspose", ctx) {}
~RemoveInverseTranspose() override = default;
bool IsSupported(const NodeDef* node) const override {
// 2) Rewrite Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage {
public:
- explicit RemoveRedundantBitcastStage(ArithmeticOptimizerContext ctx)
- : ArithmeticOptimizerStage(ctx) {}
+ explicit RemoveRedundantBitcastStage(const ArithmeticOptimizerContext& ctx)
+ : ArithmeticOptimizerStage("RemoveRedundantBitcast", ctx) {}
~RemoveRedundantBitcastStage() override = default;
bool IsSupported(const NodeDef* node) const override {
// Remove Casts whose source type and destination type are equal.
class RemoveRedundantCastStage : public ArithmeticOptimizerStage {
public:
- explicit RemoveRedundantCastStage(ArithmeticOptimizerContext ctx)
- : ArithmeticOptimizerStage(ctx) {}
+ explicit RemoveRedundantCastStage(const ArithmeticOptimizerContext& ctx)
+ : ArithmeticOptimizerStage("RemoveRedundantCast", ctx) {}
~RemoveRedundantCastStage() override = default;
bool IsSupported(const NodeDef* node) const override { return IsCast(*node); }
}
}
- // Use the commutativity and (left- and right-) distributive property of
- // multiplication over addition to hoist common factors out of aggregate nodes
- // where all the inputs are Mul nodes. This pattern occurs frequently in
- // regularization terms for the gradients during training.
- // For example, we can rewrite an expression of the form:
- // AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
- // to the following:
- // Mul(x, AddN(y1, y2, y3, ... yn))
- if (IsAggregate(*node) && NumNonControlInputs(*node) > 1 &&
- !OptimizedNodeExists(*node, "hoist_add") &&
- !OptimizedNodeExists(*node, "hoist_mul")) {
- // Determine the set of common factors if the input nodes are all Mul nodes.
- std::set<string> common_factors;
- for (int i = 0; i < node->input_size(); ++i) {
- if (i > 0 && common_factors.empty()) {
- break;
- }
- if (IsControlInput(node->input(i))) {
- break;
- }
- const NodeDef* input = node_map_->GetNode(node->input(i));
- if (input->op() == "Mul") {
- std::set<string> factors_i{input->input(0), input->input(1)};
- if (i == 0) {
- std::swap(common_factors, factors_i);
- } else {
- std::set<string> intersection;
- std::set_intersection(
- factors_i.begin(), factors_i.end(), common_factors.begin(),
- common_factors.end(),
- std::inserter(intersection, intersection.begin()));
- std::swap(common_factors, intersection);
- }
- } else {
- common_factors.clear();
- }
- }
- if (common_factors.size() == 1) {
- const string& common_factor = *common_factors.begin();
-
- // Gather up the non-shared factors (the y's in the example).
- // Unless the aggregation is Add, we have to make sure that all the y's
- // have the same shape since the other aggregation ops do not support
- // broadcasting.
- std::vector<string> unique_factors;
- unique_factors.reserve(node->input_size());
- bool shapes_match = true;
- for (int i = 0; i < node->input_size() && shapes_match; ++i) {
- const string& input = node->input(i);
- if (IsControlInput(input)) {
- break;
- }
- const NodeDef* mul_node = node_map_->GetNode(input);
- const int unique_factor_index =
- mul_node->input(0) == common_factor ? 1 : 0;
- unique_factors.push_back(mul_node->input(unique_factor_index));
- if (i > 0 && !IsAdd(*node)) {
- shapes_match = ShapesEqual(unique_factors.front(),
- unique_factors.back(), *node_map_);
- }
- }
-
- if (shapes_match) {
- // 1. Use a copy of the first Mul node for the outer multiplication.
- NodeDef* new_mul_node = AddNode(OptimizedNodeName(*node, "hoist_mul"),
- node_map_->GetNode(node->input(0)));
- NodeDef* new_add_node = AddNode(*node, "hoist_add", /*copy_node=*/true);
- new_mul_node->set_device(node->device());
- new_mul_node->set_input(0, common_factor);
- node_map_->AddOutput(common_factor, new_mul_node->name());
- new_mul_node->set_input(1, new_add_node->name());
- node_map_->AddOutput(new_add_node->name(), new_mul_node->name());
-
- // 2. Hoist non-shared factors up into the new AddN node.
- nodes_to_simplify->PushBack(new_add_node);
- for (int i = 0; i < node->input_size(); ++i) {
- const string& input = node->input(i);
- if (IsControlInput(input)) {
- break;
- }
- new_add_node->set_input(i, unique_factors[i]);
- }
-
- // 3. Add frame dependencies that the original node might have had.
- AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor,
- {new_add_node});
-
- return new_mul_node->name();
- }
- }
- }
-
// Fold Transpose into matrix multiplication.
if ((node->op() == "MatMul" || node->op() == "SparseMatMul" ||
node->op() == "BatchMatMul") &&
nodes_to_simplify.PushBack(optimized_graph_->mutable_node(i));
}
- ArithmeticOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
- node_map_.get(), &nodes_to_simplify);
+ const ArithmeticOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
+ node_map_.get(), &frame_map_,
+ &nodes_to_simplify);
std::vector<std::unique_ptr<ArithmeticOptimizerStage>> stages;
stages.push_back(
std::unique_ptr<ArithmeticOptimizerStage>(new AddOpsRewriteStage(ctx)));
}
+ if (options_.hoist_common_factor_out_of_aggregation) {
+ stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
+ new HoistCommonFactorOutOfAggregation(ctx)));
+ }
if (options_.remove_inverse_transpose) {
stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
new RemoveInverseTranspose(ctx)));
namespace {
+constexpr char kHoistFactorOptimizerMul[] =
+ "ArithmeticOptimizer/HoistCommonFactor_Mul_";
+
+constexpr char kHoistFactorOptimizerAdd[] =
+ "ArithmeticOptimizer/HoistCommonFactor_Add_";
+
+// Name of the outer Mul node created by HoistCommonFactorOutOfAggregation
+// for an aggregate node originally called `name`.
+string HoistMulName(const string& name) {
+ return AddPrefixToNodeName(name, kHoistFactorOptimizerMul, "");
+}
+
+// Name of the inner Add node created by HoistCommonFactorOutOfAggregation
+// for an aggregate node originally called `name`.
+string HoistAddName(const string& name) {
+ return AddPrefixToNodeName(name, kHoistFactorOptimizerAdd, "");
+}
+
string OptimizedName(const string& name) {
return AddPrefixToNodeName(name, kArithmeticOptimizer);
}
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
}
+ // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent.
+ void OptimizeTwice(ArithmeticOptimizer* optimizer, GrapplerItem* item,
+ GraphDef* output) {
+ TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
+ item->graph.Swap(output);
+ TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
+ }
+
// TODO(ezhulenev): Make private. After migration to stages each test
// should explicitly enable required optimization for tests isolation
void DisableAllStages(ArithmeticOptimizer* optimizer) {
ArithmeticOptimizer::ArithmeticOptimizerOptions options;
options.combine_add_to_addn = false;
+ options.hoist_common_factor_out_of_aggregation = false;
options.remove_inverse_transpose = false;
options.remove_redundant_bitcast = false;
options.remove_redundant_cast = false;
optimizer->options_ = options;
}
+ void DisableAddToAddNCombining(ArithmeticOptimizer* optimizer) {
+ optimizer->options_.combine_add_to_addn = false;
+ }
+
void EnableOnlyAddToAddNCombining(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.combine_add_to_addn = true;
}
+ void EnableOnlyHoistCommonFactor(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.hoist_common_factor_out_of_aggregation = true;
+ }
+
void EnableOnlyRemoveInverseTranspose(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_inverse_transpose = true;
}
ArithmeticOptimizer optimizer;
- DisableAllStages(&optimizer);
+ DisableAddToAddNCombining(&optimizer);
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ OptimizeTwice(&optimizer, &item, &output);
- EXPECT_EQ(17, output.node_size());
- // The graph gets optimized to
+ // We expect the following rewrite(s) to occur:
+ //
// Mul(p,
- // Add(Add(Const(2), Const(2)),
- // Add(Const(2), Const(2))))
+ // Add_6(Add_4(Const(2), Const(2)),
+ // Add_5(Const(2), Const(2))))
+ NodeMap node_map(&output);
+
EXPECT_EQ(17, output.node_size());
- for (const auto& node : output.node()) {
- if ("id" == node.name()) {
- EXPECT_EQ(1, node.input_size());
- EXPECT_EQ(OptimizedName("Add_6_hoist_mul"), node.input(0));
- } else if (OptimizedName("Add_6_hoist_mul") == node.name()) {
- EXPECT_EQ("Mul", node.op());
- EXPECT_EQ(2, node.input_size());
- EXPECT_EQ("Placeholder", node.input(0));
- EXPECT_EQ(OptimizedName("Add_6_hoist_add"), node.input(1));
- } else if (OptimizedName("Add_6_hoist_add") == node.name()) {
- EXPECT_EQ("Add", node.op());
- EXPECT_EQ(3, node.input_size());
- EXPECT_EQ(OptimizedName("Add_4_hoist_add"), node.input(0));
- EXPECT_EQ(OptimizedName("Add_5_hoist_add"), node.input(1));
- EXPECT_EQ("^Placeholder", node.input(2));
- } else if (OptimizedName("Add_4_hoist_add") == node.name()) {
- EXPECT_EQ("Add", node.op());
- EXPECT_EQ(3, node.input_size());
- EXPECT_EQ(OptimizedName("Add_const"), node.input(0));
- EXPECT_EQ(OptimizedName("Add_1_const"), node.input(1));
- EXPECT_EQ("^Placeholder", node.input(2));
- } else if (OptimizedName("Add_5_hoist_add") == node.name()) {
- EXPECT_EQ("Add", node.op());
- EXPECT_EQ(3, node.input_size());
- EXPECT_EQ(OptimizedName("Add_const"), node.input(0));
- EXPECT_EQ(OptimizedName("Add_1_const"), node.input(1));
- EXPECT_EQ("^Placeholder", node.input(2));
- } else if (OptimizedName("Add_const") == node.name()) {
- EXPECT_EQ("Const", node.op());
- EXPECT_EQ(1, node.input_size());
- EXPECT_EQ("^Placeholder", node.input(0));
- } else if (OptimizedName("Add_1_const") == node.name()) {
- EXPECT_EQ("Const", node.op());
- EXPECT_EQ(1, node.input_size());
- EXPECT_EQ("^Placeholder", node.input(0));
- }
- }
+
+ const NodeDef* id_node = node_map.GetNode("id");
+ ASSERT_TRUE(id_node != nullptr);
+ EXPECT_EQ(1, id_node->input_size());
+ EXPECT_EQ(HoistMulName("Add_6"), id_node->input(0));
+
+ const NodeDef* mul_node = node_map.GetNode(HoistMulName("Add_6"));
+ ASSERT_TRUE(mul_node != nullptr);
+ EXPECT_EQ(2, mul_node->input_size());
+ EXPECT_EQ("Placeholder", mul_node->input(0));
+ EXPECT_EQ(HoistAddName("Add_6"), mul_node->input(1));
+
+ const NodeDef* add_6_node = node_map.GetNode(HoistAddName("Add_6"));
+ ASSERT_TRUE(add_6_node != nullptr);
+ EXPECT_EQ(3, add_6_node->input_size());
+ EXPECT_EQ(HoistAddName("Add_4"), add_6_node->input(0));
+ EXPECT_EQ(HoistAddName("Add_5"), add_6_node->input(1));
+ EXPECT_EQ("^Placeholder", add_6_node->input(2));
+
+ const NodeDef* add_4_node = node_map.GetNode(HoistAddName("Add_4"));
+ ASSERT_TRUE(add_4_node != nullptr);
+ EXPECT_EQ("Add", add_4_node->op());
+ EXPECT_EQ(3, add_4_node->input_size());
+ EXPECT_EQ(OptimizedName("Add_const"), add_4_node->input(0));
+ EXPECT_EQ(OptimizedName("Add_1_const"), add_4_node->input(1));
+ EXPECT_EQ("^Placeholder", add_4_node->input(2));
+
+ const NodeDef* add_5_node = node_map.GetNode(HoistAddName("Add_5"));
+ ASSERT_TRUE(add_5_node != nullptr);
+ EXPECT_EQ("Add", add_5_node->op());
+ EXPECT_EQ(3, add_5_node->input_size());
+ EXPECT_EQ(OptimizedName("Add_const"), add_5_node->input(0));
+ EXPECT_EQ(OptimizedName("Add_1_const"), add_5_node->input(1));
+ EXPECT_EQ("^Placeholder", add_5_node->input(2));
+
+ const NodeDef* add_const_node = node_map.GetNode(OptimizedName("Add_const"));
+ ASSERT_TRUE(add_const_node != nullptr);
+ EXPECT_EQ("Const", add_const_node->op());
+ EXPECT_EQ(1, add_const_node->input_size());
+ EXPECT_EQ("^Placeholder", add_const_node->input(0));
+
+ const NodeDef* add_1_const_node =
+ node_map.GetNode(OptimizedName("Add_1_const"));
+ ASSERT_TRUE(add_1_const_node != nullptr);
+ EXPECT_EQ("Const", add_1_const_node->op());
+ EXPECT_EQ(1, add_1_const_node->input_size());
+ EXPECT_EQ("^Placeholder", add_1_const_node->input(0));
}
TEST_F(ArithmeticOptimizerTest, HoistFactor) {
ops::Add(s.WithOpName("add"), mul1, mul2));
GrapplerItem item;
+ item.fetch = {"id"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
ArithmeticOptimizer optimizer;
+ EnableOnlyHoistCommonFactor(&optimizer);
+
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ OptimizeTwice(&optimizer, &item, &output);
+
+ // We expect the following rewrite(s) to occur:
+ //
+ // Add Mul
+ // / \ / \
+ // Mul Mul -> x Add
+ // / \ / \ / \
+ // x y1 y2 x y1 y2
+ //
+ // If the "root" op is AddN and the shapes do not match, this rewrite is
+ // not possible and the graph should stay intact.
+ NodeMap node_map(&output);
if (use_addn && !matching_shapes) {
VerifyGraphsMatch(item.graph, output, __LINE__);
} else {
EXPECT_EQ(9, output.node_size());
- const NodeDef& new_add = output.node(8);
- EXPECT_EQ(OptimizedName("add_hoist_add"), new_add.name());
- EXPECT_EQ("y1", new_add.input(0));
- EXPECT_EQ("y2", new_add.input(1));
- const NodeDef& new_mul = output.node(7);
- EXPECT_EQ(OptimizedName("add_hoist_mul"), new_mul.name());
- EXPECT_EQ("x", new_mul.input(0));
- EXPECT_EQ(OptimizedName("add_hoist_add"), new_mul.input(1));
- const NodeDef& new_id = output.node(6);
- EXPECT_EQ("id", new_id.name());
- EXPECT_EQ(OptimizedName("add_hoist_mul"), new_id.input(0));
+
+ const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
+ ASSERT_TRUE(new_add_node != nullptr) << "Hoisted Add node not found";
+ EXPECT_EQ("y1", new_add_node->input(0));
+ EXPECT_EQ("y2", new_add_node->input(1));
+
+ const NodeDef* new_mul_node = node_map.GetNode(HoistMulName("add"));
+ ASSERT_TRUE(new_mul_node != nullptr) << "Hoisted Mul node not found";
+ EXPECT_EQ("x", new_mul_node->input(0));
+ EXPECT_EQ(new_add_node->name(), new_mul_node->input(1));
+
+ const NodeDef* id_node = node_map.GetNode("id");
+ ASSERT_TRUE(id_node != nullptr) << "Id node not found";
+ EXPECT_EQ("id", id_node->name());
+ EXPECT_EQ(HoistMulName("add"), id_node->input(0));
}
}
}
NodeMap node_map(&output);
// check add tree was replaced with AddN
- const NodeDef* collapsed_add = CHECK_NOTNULL(
- node_map.GetNode("y/ArithmeticOptimizer/AddOpsGroup_Add_abc_Add_ab"));
+ const NodeDef* collapsed_add =
+ node_map.GetNode("y/ArithmeticOptimizer/AddOpsRewrite_Add_abc_Add_ab");
+ ASSERT_TRUE(collapsed_add != nullptr);
EXPECT_EQ("AddN", collapsed_add->op());
EXPECT_EQ(3, collapsed_add->input_size());
EXPECT_EQ("c", collapsed_add->input(2));
// check output was re-wired to new node
- const NodeDef* updated_outputs = CHECK_NOTNULL(node_map.GetNode("outputs"));
+ const NodeDef* updated_outputs = node_map.GetNode("outputs");
+ ASSERT_TRUE(updated_outputs != nullptr);
EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
}
NodeMap node_map(&output);
// check left Add subtree replaced with AddN
- const NodeDef* collapsed_left = CHECK_NOTNULL(
- node_map.GetNode("ArithmeticOptimizer/AddOpsGroup_Add_abc_Add_ab"));
+ const NodeDef* collapsed_left =
+ node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc_Add_ab");
+ ASSERT_TRUE(collapsed_left != nullptr);
EXPECT_EQ("AddN", collapsed_left->op());
EXPECT_EQ(3, collapsed_left->input_size());
EXPECT_EQ("c", collapsed_left->input(2));
// check right Add subtree replaced with AddN
- const NodeDef* collapsed_right = CHECK_NOTNULL(
- node_map.GetNode("ArithmeticOptimizer/AddOpsGroup_Add_xyz_Add_xy"));
+ const NodeDef* collapsed_right =
+ node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_xyz_Add_xy");
+ ASSERT_TRUE(collapsed_right != nullptr);
EXPECT_EQ("AddN", collapsed_right->op());
EXPECT_EQ(3, collapsed_right->input_size());
EXPECT_EQ("z", collapsed_right->input(2));
// check that Mul inputs re-wired to new Nodes
- const NodeDef* updated_mul = CHECK_NOTNULL(node_map.GetNode("Mul"));
+ const NodeDef* updated_mul = node_map.GetNode("Mul");
+ ASSERT_TRUE(updated_mul != nullptr);
EXPECT_EQ("Mul", updated_mul->op());
EXPECT_EQ(2, updated_mul->input_size());
NodeMap node_map(&output);
// check Add tree replaced with AddN
- const NodeDef* collapsed_add = CHECK_NOTNULL(node_map.GetNode(
- "ArithmeticOptimizer/AddOpsGroup_Add_all_Add_ab_Add_bc"));
+ const NodeDef* collapsed_add = node_map.GetNode(
+ "ArithmeticOptimizer/AddOpsRewrite_Add_all_Add_ab_Add_bc");
+ ASSERT_TRUE(collapsed_add != nullptr);
EXPECT_EQ("AddN", collapsed_add->op());
EXPECT_EQ(4, collapsed_add->input_size());