Fix regression caused by cl/191020868: Re-use materialized shapes for other broadcast...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 5 Apr 2018 19:18:34 +0000 (12:18 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 5 Apr 2018 19:21:29 +0000 (12:21 -0700)
PiperOrigin-RevId: 191779263

tensorflow/core/grappler/optimizers/constant_folding.cc

index d941a0b..2f1b9e4 100644 (file)
@@ -552,7 +552,6 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
 
   const DataType type = node.attr().at("T").type();
   NodeDef* out[2];
-  bool created_const = false;
   for (int j = 0; j < 2; ++j) {
     int reduction_indices = reduce_dims[j].size();
     Tensor value(type, TensorShape({reduction_indices}));
@@ -576,20 +575,17 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
           AddControlDependency(node.name(), graph_, node_map_.get());
       *out[j]->add_input() = ctrl_dep;
       node_map_->AddOutput(NodeName(ctrl_dep), const_name);
-      created_const = true;
     }
   }
 
-  if (created_const) {
-    const std::set<NodeDef*> outputs = node_map_->GetOutputs(node.name());
-    for (NodeDef* output : outputs) {
-      for (int k = 0; k < output->input_size(); ++k) {
-        int port;
-        string node_name = ParseNodeName(output->input(k), &port);
-        if (node_name == node.name() && port >= 0 && port < 2 && out[port]) {
-          *output->mutable_input(k) = out[port]->name();
-          node_map_->UpdateInput(output->name(), node_name, out[port]->name());
-        }
+  const std::set<NodeDef*> outputs = node_map_->GetOutputs(node.name());
+  for (NodeDef* output : outputs) {
+    for (int k = 0; k < output->input_size(); ++k) {
+      int port;
+      string node_name = ParseNodeName(output->input(k), &port);
+      if (node_name == node.name() && port >= 0 && port < 2 && out[port]) {
+        *output->mutable_input(k) = out[port]->name();
+        node_map_->UpdateInput(output->name(), node_name, out[port]->name());
       }
     }
   }