void DependencyOptimizer::OptimizeNode(int node_idx,
SetVector<int>* nodes_to_simplify,
std::set<int>* nodes_to_delete) {
- const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
NodeDef* node = optimized_graph_->mutable_node(node_idx);
const bool is_noop = IsNoOp(*node);
const bool is_identity = IsIdentity(*node);
// y --^> | | --^> b /\ +---+
// +----------+ y --^> b
- if (is_noop || (is_identity && is_aggressive)) {
+ if (is_noop || is_identity) {
const auto& output_node_set = node_map_->GetOutputs(node_name);
const std::vector<NodeDef*> output_nodes(output_node_set.begin(),
output_node_set.end());
item.fetch.push_back("neg1");
item.fetch.push_back("neg2");
- DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ DependencyOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
TF_CHECK_OK(s.ToGraphDef(&item.graph));
item.fetch = {"a_a", "a_b", "a_c", "a_d", "b_a", "c_a", "c_b"};
- DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ DependencyOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
item.fetch.push_back("or0");
item.fetch.push_back("or1");
item.fetch.push_back("or2");
- DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ DependencyOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
item.fetch.push_back("neg2");
- DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ DependencyOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
item.fetch.push_back("id2");
item.fetch.push_back("fetch");
- DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ DependencyOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
op_perfs, run_time, step_stats = grappler_cluster.MeasureCosts(
grappler_item)
self.assertTrue(run_time > 0)
- self.assertEqual(len(op_perfs), 9)
+ self.assertEqual(len(op_perfs), 7)
self.assertTrue(step_stats.dev_stats)
def testNoDetailedStats(self):
disable_detailed_stats=False, disable_timeline=False) as gcluster:
op_perfs, run_time, step_stats = gcluster.MeasureCosts(grappler_item)
self.assertTrue(run_time > 0)
- self.assertEqual(len(op_perfs), 9)
+ self.assertEqual(len(op_perfs), 7)
self.assertTrue(step_stats.dev_stats)
def testAvailableOps(self):