Extend the memory optimizations to also support accumulate_n ops
author: Benoit Steiner <bsteiner@google.com>
Mon, 12 Feb 2018 21:25:57 +0000 (13:25 -0800)
committer: TensorFlower Gardener <gardener@tensorflow.org>
Mon, 12 Feb 2018 21:29:55 +0000 (13:29 -0800)
PiperOrigin-RevId: 185425999

tensorflow/core/grappler/optimizers/memory_optimizer.cc

index ef178adc1408cdd5b63b34c5e36178a7f27eb026..777cc3a79bc1267711cbaad0148a286bbb74266f 100644 (file)
@@ -490,12 +490,12 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
 }
 
 bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
-  // Look for AddN nodes and record input names.
+  // Look for AddN nodes (and equivalent) and record input names.
   GraphView view(&item->graph);
 
   std::unordered_map<string, std::unordered_set<NodeDef*>> addn_list;
   for (NodeDef& node : *item->graph.mutable_node()) {
-    if (!IsAddN(node)) {
+    if (!IsAddN(node) && node.op() != "AccumulateNV2") {
       continue;
     }
     // There is nothing to gain by optimizing nodes with 2 or fewer inputs.