}
void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
- const string& recomputation_targets_name_prefix,
+ const string& recomputation_targets_name_scope,
GraphDef* graph, const GrapplerItem& item) {
if (optimization_level != RewriterConfig::RECOMPUTATION_HEURISTICS &&
optimization_level != RewriterConfig::HEURISTICS &&
feeds.insert(NodeName(feed.first));
}
std::function<bool(const NodeDef&)> is_target =
- [&recomputation_targets_name_prefix](const NodeDef& node) {
- // Nodes whose inputs we may want to recompute. Typically targets will
- // be gradients (recomputation_targets_name_prefix="gradients/"),
- // although the prefix is configurable since gradients may be created
- // in a name scope.
- // TODO(allenl): Use a static schedule
- // (grappler::EstimateEarliestExecutionTimes) to recompute only nodes
- // whose outputs will sit around for a while.
- return node.name().find(recomputation_targets_name_prefix) == 0;
+ [&recomputation_targets_name_scope](const NodeDef& node) {
+ // Nodes whose inputs we may want to recompute. This matches node names
+ // that contain recomputation_targets_name_scope as a name scope,
+ // meaning it either begins with or contains the name scope.
+ // Defaults to "gradients/" which will match any node name that begins
+ // with "gradients/" or contains "/gradients/".
+ return node.name().find(recomputation_targets_name_scope) == 0 ||
+ node.name().find("/" + recomputation_targets_name_scope) != -1;
};
if (optimization_level == RewriterConfig::RECOMPUTATION_HEURISTICS ||
*optimized_graph = item.graph;
RecomputationRewritingPass(optimization_level_,
- recomputation_targets_name_prefix_,
- optimized_graph, item);
+ recomputation_targets_name_scope_, optimized_graph,
+ item);
GrapplerItem optimized_item(item, std::move(*optimized_graph));
std::unordered_set<string> skip_list;
public:
// optimization_level: Controls the level of autonomy for the memory
// optimizer. See RewriterConfig::memory_optimization.
- // recomputation_targets_name_prefix: Name prefix for potential outputs of
+ // recomputation_targets_name_scope: Name scope for potential outputs of
// recomputations. See
- // RewriterConfig::memory_optimizer_target_node_name_prefix.
+ // RewriterConfig::memory_optimizer_target_node_name_scope.
explicit MemoryOptimizer(
RewriterConfig::MemOptType optimization_level,
- const string& recomputation_targets_name_prefix = "gradients/")
+ const string& recomputation_targets_name_scope = "gradients/")
: optimization_level_(optimization_level),
- recomputation_targets_name_prefix_(recomputation_targets_name_prefix) {}
+ recomputation_targets_name_scope_(recomputation_targets_name_scope) {}
~MemoryOptimizer() override {}
string name() const override { return "memory_optimizer"; };
private:
RewriterConfig::MemOptType optimization_level_;
- string recomputation_targets_name_prefix_;
+ string recomputation_targets_name_scope_;
};
} // end namespace grappler
std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
}
if (cfg_.memory_optimization() != RewriterConfig::NO_MEM_OPT) {
- if (cfg_.memory_optimizer_target_node_name_prefix().empty()) {
+ if (cfg_.memory_optimizer_target_node_name_scope().empty()) {
optimizers.push_back(std::unique_ptr<GraphOptimizer>(
// Use the default target node name prefix "gradients/"
new MemoryOptimizer(cfg_.memory_optimization())));
optimizers.push_back(
std::unique_ptr<GraphOptimizer>(new MemoryOptimizer(
cfg_.memory_optimization(),
- cfg_.memory_optimizer_target_node_name_prefix())));
+ cfg_.memory_optimizer_target_node_name_scope())));
}
}
if (cfg_.auto_parallel().enable()) {
// effect on manually requested memory optimization passes in the optimizers
// field.
MemOptType memory_optimization = 4;
- // The prefix for nodes which are valid outputs of recomputations. Inputs to
- // nodes with this name prefix may be recomputed (subject either to manual
- // annotation of those input nodes or to manual annotation and heuristics
- // depending on memory_optimization), but the prefixed nodes themselves will
- // not be recomputed. Typically this will be "gradients/", indicating that
- // activations from the forward pass of a graph may be recomputed as inputs to
- // gradients, but may be adjusted if gradients are inside a name scope or if
- // inputs to non-gradients should be recomputed. Defaults to "gradients/" if
- // empty or not set.
- string memory_optimizer_target_node_name_prefix = 6;
+ // A node name scope for node names which are valid outputs of recomputations.
+ // Inputs to nodes that match this scope may be recomputed (subject either to
+ // manual annotation of those input nodes or to manual annotation and
+ // heuristics depending on memory_optimization), but the nodes themselves will
+ // not be recomputed. This matches any sub-scopes as well, meaning the scope
+ // can appear not just as a top-level scope. For example, if the value is
+ // "gradients/", the default, it will match node names "gradients/foo",
+ // "foo/gradients/bar", but not "foo_gradients/".
+ string memory_optimizer_target_node_name_scope = 6;
// Configures AutoParallel optimization passes either through the
// meta-optimizer or when manually specified through the optimizers field.
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
memory_optimization=rewriter_config_pb2.RewriterConfig.
RECOMPUTATION_HEURISTICS,
- memory_optimizer_target_node_name_prefix='optimizer/gradients/'),
+ # Checks that the name scope "gradients/" also matches sub-scopes.
+ memory_optimizer_target_node_name_scope='gradients/'),
original_metagraph)
self.assertGreater(
len(rewritten_graph_def.node),
len([node for node in rewritten_graph_def.node
if 'Recomputed/' in node.name]))
+ def testRewritingNameScopedGradientNamesScope(self):
+ """Tests that no rewriting occurs when the name scope matches nothing."""
+ (original_metagraph, _, _,
+ _) = self._GetMetaGraph(optimizer_scope_name='foo/bar')
+ rewritten_graph_def = tf_optimizer.OptimizeGraph(
+ rewriter_config_pb2.RewriterConfig(
+ disable_model_pruning=True,
+ constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
+ dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
+ arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ memory_optimization=rewriter_config_pb2.RewriterConfig.
+ RECOMPUTATION_HEURISTICS,
+ # This should not match anything.
+ memory_optimizer_target_node_name_scope='r/gradients/'),
+ original_metagraph)
+ self.assertEqual(
+ len(rewritten_graph_def.node), len(original_metagraph.graph_def.node))
+ self.assertEqual(0,
+ len([
+ node for node in original_metagraph.graph_def.node
+ if 'Recomputed/' in node.name
+ ]))
+ self.assertEqual(0,
+ len([
+ node for node in rewritten_graph_def.node
+ if 'Recomputed/' in node.name
+ ]))
+
def _GetMemoryOptimizerSessionConfig(self):
rewrite_options = rewriter_config_pb2.RewriterConfig(
disable_model_pruning=True,