From: A. Unique TensorFlower Date: Thu, 5 Apr 2018 19:18:34 +0000 (-0700) Subject: Fix regression caused by cl/191020868: Re-use materialized shapes for other broadcast... X-Git-Tag: tflite-v0.1.7~16^2^2~147 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=feb8d7b53953826e0d1b4bc68726392ac0ab310b;p=platform%2Fupstream%2Ftensorflow.git Fix regression caused by cl/191020868: Re-use materialized shapes for other broadcast gradient shape nodes. PiperOrigin-RevId: 191779263 --- diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index d941a0b..2f1b9e4 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -552,7 +552,6 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs( const DataType type = node.attr().at("T").type(); NodeDef* out[2]; - bool created_const = false; for (int j = 0; j < 2; ++j) { int reduction_indices = reduce_dims[j].size(); Tensor value(type, TensorShape({reduction_indices})); @@ -576,20 +575,17 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs( AddControlDependency(node.name(), graph_, node_map_.get()); *out[j]->add_input() = ctrl_dep; node_map_->AddOutput(NodeName(ctrl_dep), const_name); - created_const = true; } } - if (created_const) { - const std::set outputs = node_map_->GetOutputs(node.name()); - for (NodeDef* output : outputs) { - for (int k = 0; k < output->input_size(); ++k) { - int port; - string node_name = ParseNodeName(output->input(k), &port); - if (node_name == node.name() && port >= 0 && port < 2 && out[port]) { - *output->mutable_input(k) = out[port]->name(); - node_map_->UpdateInput(output->name(), node_name, out[port]->name()); - } + const std::set outputs = node_map_->GetOutputs(node.name()); + for (NodeDef* output : outputs) { + for (int k = 0; k < output->input_size(); ++k) { + int port; + string node_name = ParseNodeName(output->input(k), &port); + if (node_name == node.name() && port >= 0 && port < 2 && out[port]) { + *output->mutable_input(k) = out[port]->name(); + node_map_->UpdateInput(output->name(), node_name, out[port]->name()); } } }