const DataType type = node.attr().at("T").type();
NodeDef* out[2];
- bool created_const = false;
for (int j = 0; j < 2; ++j) {
int reduction_indices = reduce_dims[j].size();
Tensor value(type, TensorShape({reduction_indices}));
AddControlDependency(node.name(), graph_, node_map_.get());
*out[j]->add_input() = ctrl_dep;
node_map_->AddOutput(NodeName(ctrl_dep), const_name);
- created_const = true;
}
}
- if (created_const) {
- const std::set<NodeDef*> outputs = node_map_->GetOutputs(node.name());
- for (NodeDef* output : outputs) {
- for (int k = 0; k < output->input_size(); ++k) {
- int port;
- string node_name = ParseNodeName(output->input(k), &port);
- if (node_name == node.name() && port >= 0 && port < 2 && out[port]) {
- *output->mutable_input(k) = out[port]->name();
- node_map_->UpdateInput(output->name(), node_name, out[port]->name());
- }
+ const std::set<NodeDef*> outputs = node_map_->GetOutputs(node.name());
+ for (NodeDef* output : outputs) {
+ for (int k = 0; k < output->input_size(); ++k) {
+ int port;
+ string node_name = ParseNodeName(output->input(k), &port);
+ if (node_name == node.name() && port >= 0 && port < 2 && out[port]) {
+ *output->mutable_input(k) = out[port]->name();
+ node_map_->UpdateInput(output->name(), node_name, out[port]->name());
}
}
}