],
visibility = [
"//tensorflow/compiler:__subpackages__",
- "//tensorflow/core/grappler:__subpackages__",
"//tensorflow/core/profiler:__subpackages__",
],
deps = [":lib_internal"],
":graph_rewriter",
":static_schedule",
"//tensorflow/core:framework",
- "//tensorflow/core:regexp_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_view",
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/grappler/utils/traversal.h"
-#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
}
void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
- const string& recomputation_targets_name_regexp,
+ const string& recomputation_targets_name_prefix,
GraphDef* graph, const GrapplerItem& item) {
if (optimization_level != RewriterConfig::RECOMPUTATION_HEURISTICS &&
optimization_level != RewriterConfig::HEURISTICS &&
for (const auto& feed : item.feed) {
feeds.insert(NodeName(feed.first));
}
- RE2 recomputation_targets_re(recomputation_targets_name_regexp);
std::function<bool(const NodeDef&)> is_target =
- [&recomputation_targets_re](const NodeDef& node) {
- // Nodes whose inputs we may want to recompute. This does a prefix
- // regexp match, and typically one sets regexp="gradients/" meaning
- // it will match all node names with scope beginning with "gradients/".
- // If used within scopes, one may want to set regexp="(.+/)?gradients/".
+ [&recomputation_targets_name_prefix](const NodeDef& node) {
+ // Nodes whose inputs we may want to recompute. Typically targets will
+ // be gradients (recomputation_targets_name_prefix="gradients/"),
+ // although the prefix is configurable since gradients may be created
+ // in a name scope.
// TODO(allenl): Use a static schedule
// (grappler::EstimateEarliestExecutionTimes) to recompute only nodes
// whose outputs will sit around for a while.
- bool match = recomputation_targets_re.Match(
- node.name(), 0, node.name().size(), RE2::ANCHOR_START, nullptr, 0);
- return match;
+ return node.name().find(recomputation_targets_name_prefix) == 0;
};
if (optimization_level == RewriterConfig::RECOMPUTATION_HEURISTICS ||
*optimized_graph = item.graph;
RecomputationRewritingPass(optimization_level_,
- recomputation_targets_name_regexp_,
+ recomputation_targets_name_prefix_,
optimized_graph, item);
GrapplerItem optimized_item(item, std::move(*optimized_graph));
public:
// optimization_level: Controls the level of autonomy for the memory
// optimizer. See RewriterConfig::memory_optimization.
- // recomputation_targets_name_regxp: Name regxp for potential outputs of
+ // recomputation_targets_name_prefix: Name prefix for potential outputs of
// recomputations. See
- // RewriterConfig::memory_optimizer_target_node_name_regxp.
+ // RewriterConfig::memory_optimizer_target_node_name_prefix.
explicit MemoryOptimizer(
RewriterConfig::MemOptType optimization_level,
- const string& recomputation_targets_name_regexp = "gradients/")
+ const string& recomputation_targets_name_prefix = "gradients/")
: optimization_level_(optimization_level),
- recomputation_targets_name_regexp_(recomputation_targets_name_regexp) {}
+ recomputation_targets_name_prefix_(recomputation_targets_name_prefix) {}
~MemoryOptimizer() override {}
string name() const override { return "memory_optimizer"; };
private:
RewriterConfig::MemOptType optimization_level_;
- string recomputation_targets_name_regexp_;
+ string recomputation_targets_name_prefix_;
};
} // end namespace grappler
std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
}
if (cfg_.memory_optimization() != RewriterConfig::NO_MEM_OPT) {
- if (cfg_.memory_optimizer_target_node_name_regexp().empty()) {
+ if (cfg_.memory_optimizer_target_node_name_prefix().empty()) {
optimizers.push_back(std::unique_ptr<GraphOptimizer>(
// Use the default target node name prefix "gradients/"
new MemoryOptimizer(cfg_.memory_optimization())));
optimizers.push_back(
std::unique_ptr<GraphOptimizer>(new MemoryOptimizer(
cfg_.memory_optimization(),
- cfg_.memory_optimizer_target_node_name_regexp())));
+ cfg_.memory_optimizer_target_node_name_prefix())));
}
}
if (cfg_.auto_parallel().enable()) {
// effect on manually requested memory optimization passes in the optimizers
// field.
MemOptType memory_optimization = 4;
- // A regexp for node names which are valid outputs of recomputations. Inputs
- // to nodes that match this regexp may be recomputed (subject either to manual
+ // The prefix for nodes which are valid outputs of recomputations. Inputs to
+ // nodes with this name prefix may be recomputed (subject either to manual
// annotation of those input nodes or to manual annotation and heuristics
- // depending on memory_optimization), but the nodes themselves will not be
- // recomputed. This is a prefix match, meaning it matches any node name that
- // contains a prefix that matches this regexp. Defaults to "gradients/" if
- // not provided, but can be changed if used within scopes.
- string memory_optimizer_target_node_name_regexp = 6;
+ // depending on memory_optimization), but the prefixed nodes themselves will
+ // not be recomputed. Typically this will be "gradients/", indicating that
+ // activations from the forward pass of a graph may be recomputed as inputs to
+ // gradients, but may be adjusted if gradients are inside a name scope or if
+ // inputs to non-gradients should be recomputed. Defaults to "gradients/" if
+ // empty or not set.
+ string memory_optimizer_target_node_name_prefix = 6;
// Configures AutoParallel optimization passes either through the
// meta-optimizer or when manually specified through the optimizers field.
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
memory_optimization=rewriter_config_pb2.RewriterConfig.
RECOMPUTATION_HEURISTICS,
- memory_optimizer_target_node_name_regexp='optimizer/gradients/'),
- original_metagraph)
- self.assertGreater(
- len(rewritten_graph_def.node),
- len(original_metagraph.graph_def.node))
- self.assertEqual(
- 0,
- len([node for node in original_metagraph.graph_def.node
- if 'Recomputed/' in node.name]))
- self.assertEqual(
- 20, # Two per layer
- len([node for node in rewritten_graph_def.node
- if 'Recomputed/' in node.name]))
-
- def testRewritingNameScopedGradientNamesRegexp(self):
- """Tests that rewriting occurs with non-standard gradient names."""
- (original_metagraph, _, _, _) = self._GetMetaGraph(
- optimizer_scope_name='foo/bar')
- rewritten_graph_def = tf_optimizer.OptimizeGraph(
- rewriter_config_pb2.RewriterConfig(
- disable_model_pruning=True,
- constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
- dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
- layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
- arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
- memory_optimization=rewriter_config_pb2.RewriterConfig.
- RECOMPUTATION_HEURISTICS,
- memory_optimizer_target_node_name_regexp='(.+/)gradients/'),
+ memory_optimizer_target_node_name_prefix='optimizer/gradients/'),
original_metagraph)
self.assertGreater(
len(rewritten_graph_def.node),